diff --git a/datavec/datavec-data/datavec-data-audio/pom.xml b/datavec/datavec-data/datavec-data-audio/pom.xml
deleted file mode 100644
index d667b13ac..000000000
--- a/datavec/datavec-data/datavec-data-audio/pom.xml
+++ /dev/null
@@ -1,77 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<project xmlns="http://maven.apache.org/POM/4.0.0"
-         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
-         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
-
-    <modelVersion>4.0.0</modelVersion>
-
-    <parent>
-        <groupId>org.datavec</groupId>
-        <artifactId>datavec-data</artifactId>
-        <version>1.0.0-SNAPSHOT</version>
-    </parent>
-
-    <artifactId>datavec-data-audio</artifactId>
-
-    <name>datavec-data-audio</name>
-
-    <dependencies>
-        <dependency>
-            <groupId>org.datavec</groupId>
-            <artifactId>datavec-api</artifactId>
-        </dependency>
-        <dependency>
-            <groupId>org.bytedeco</groupId>
-            <artifactId>javacpp</artifactId>
-            <version>${javacpp.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>org.bytedeco</groupId>
-            <artifactId>javacv</artifactId>
-            <version>${javacv.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.github.wendykierp</groupId>
-            <artifactId>JTransforms</artifactId>
-            <version>${jtransforms.version}</version>
-            <classifier>with-dependencies</classifier>
-        </dependency>
-    </dependencies>
-
-    <profiles>
-        <profile>
-            <id>test-nd4j-native</id>
-        </profile>
-        <profile>
-            <id>test-nd4j-cuda-11.0</id>
-        </profile>
-    </profiles>
-</project>
diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/Wave.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/Wave.java
deleted file mode 100644
index 7071dfc70..000000000
--- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/Wave.java
+++ /dev/null
@@ -1,329 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.audio;
-
-
-import org.datavec.audio.extension.NormalizedSampleAmplitudes;
-import org.datavec.audio.extension.Spectrogram;
-import org.datavec.audio.fingerprint.FingerprintManager;
-import org.datavec.audio.fingerprint.FingerprintSimilarity;
-import org.datavec.audio.fingerprint.FingerprintSimilarityComputer;
-
-import java.io.FileInputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.io.Serializable;
-
-public class Wave implements Serializable {
-
-    private static final long serialVersionUID = 1L;
-    private WaveHeader waveHeader;
-    private byte[] data; // little endian
-    private byte[] fingerprint;
-
-    /**
-     * Constructor
-     *
-     */
-    public Wave() {
-        this.waveHeader = new WaveHeader();
-        this.data = new byte[0];
-    }
-
-    /**
-     * Constructor
-     *
-     * @param filename
-     *            Wave file
-     */
-    public Wave(String filename) {
-        try {
-            InputStream inputStream = new FileInputStream(filename);
-            initWaveWithInputStream(inputStream);
-            inputStream.close();
-        } catch (IOException e) {
-            System.out.println(e.toString());
-        }
-    }
-
-    /**
-     * Constructor
-     *
-     * @param inputStream
-     *            Wave file input stream
-     */
-    public Wave(InputStream inputStream) {
-        initWaveWithInputStream(inputStream);
-    }
-
-    /**
-     * Constructor
-     *
-     * @param waveHeader
-     * @param data
-     */
-    public Wave(WaveHeader waveHeader, byte[] data) {
-        this.waveHeader = waveHeader;
-        this.data = data;
-    }
-
-    private void initWaveWithInputStream(InputStream inputStream) {
-        // reads the first 44 bytes for header
-        waveHeader = new WaveHeader(inputStream);
-
-        if (waveHeader.isValid()) {
-            // load data
-            try {
-                data = new byte[inputStream.available()];
-                inputStream.read(data);
-            } catch (IOException e) {
-                System.err.println(e.toString());
-            }
-            // end load data
-        } else {
-            System.err.println("Invalid Wave Header");
-        }
-    }
-
-    /**
-     * Trim the wave data
-     *
-     * @param leftTrimNumberOfSample
-     *            Number of samples trimmed from the beginning
-     * @param rightTrimNumberOfSample
-     *            Number of samples trimmed from the end
-     */
-    public void trim(int leftTrimNumberOfSample, int rightTrimNumberOfSample) {
-
-        long chunkSize = waveHeader.getChunkSize();
-        long subChunk2Size = waveHeader.getSubChunk2Size();
-
-        long totalTrimmed = leftTrimNumberOfSample + rightTrimNumberOfSample;
-
-        if (totalTrimmed > subChunk2Size) {
-            leftTrimNumberOfSample = (int) subChunk2Size;
-        }
-
-        // update wav info
-        chunkSize -= totalTrimmed;
-        subChunk2Size -= totalTrimmed;
-
-        if (chunkSize >= 0 && subChunk2Size >= 0) {
-            waveHeader.setChunkSize(chunkSize);
-            waveHeader.setSubChunk2Size(subChunk2Size);
-
-            byte[] trimmedData = new byte[(int) subChunk2Size];
-            System.arraycopy(data, (int) leftTrimNumberOfSample, trimmedData, 0, (int) subChunk2Size);
-            data = trimmedData;
-        } else {
-            System.err.println("Trim error: Negative length");
-        }
-    }
-
-    /**
-     * Trim the wave data from the beginning
-     *
-     * @param numberOfSample
-     *            Number of samples trimmed from the beginning
-     */
-    public void leftTrim(int numberOfSample) {
-        trim(numberOfSample, 0);
-    }
-
-    /**
-     * Trim the wave data from the end
-     *
-     * @param numberOfSample
-     *            Number of samples trimmed from the end
-     */
-    public void rightTrim(int numberOfSample) {
-        trim(0, numberOfSample);
-    }
-
-    /**
-     * Trim the wave data
-     *
-     * @param leftTrimSecond
-     *            Seconds trimmed from the beginning
-     * @param rightTrimSecond
-     *            Seconds trimmed from the end
-     */
-    public void trim(double leftTrimSecond, double rightTrimSecond) {
-
-        int sampleRate = waveHeader.getSampleRate();
-        int bitsPerSample = waveHeader.getBitsPerSample();
-        int channels = waveHeader.getChannels();
-
-        int leftTrimNumberOfSample = (int) (sampleRate * bitsPerSample / 8 * channels * leftTrimSecond);
-        int rightTrimNumberOfSample = (int) (sampleRate * bitsPerSample / 8 * channels * rightTrimSecond);
-
-        trim(leftTrimNumberOfSample, rightTrimNumberOfSample);
-    }
-
-    /**
-     * Trim the wave data from the beginning
-     *
-     * @param second
-     *            Seconds trimmed from the beginning
-     */
-    public void leftTrim(double second) {
-        trim(second, 0);
-    }
-
-    /**
-     * Trim the wave data from the end
-     *
-     * @param second
-     *            Seconds trimmed from the end
-     */
-    public void rightTrim(double second) {
-        trim(0, second);
-    }
-
-    /**
-     * Get the wave header
-     *
-     * @return waveHeader
-     */
-    public WaveHeader getWaveHeader() {
-        return waveHeader;
-    }
-
-    /**
-     * Get the wave spectrogram
-     *
-     * @return spectrogram
-     */
-    public Spectrogram getSpectrogram() {
-        return new Spectrogram(this);
-    }
-
-    /**
-     * Get the wave spectrogram
-     *
-     * @param fftSampleSize number of samples in the FFT; the value must be a power of 2
-     * @param overlapFactor 1/overlapFactor overlap, e.g. 1/4 = 25% overlap, 0 for no overlap
-     *
-     * @return spectrogram
-     */
-    public Spectrogram getSpectrogram(int fftSampleSize, int overlapFactor) {
-        return new Spectrogram(this, fftSampleSize, overlapFactor);
-    }
-
-    /**
-     * Get the wave data in bytes
-     *
-     * @return wave data
-     */
-    public byte[] getBytes() {
-        return data;
-    }
-
-    /**
-     * Data byte size of the wave excluding header size
-     *
-     * @return byte size of the wave
-     */
-    public int size() {
-        return data.length;
-    }
-
-    /**
-     * Length of the wave in seconds
-     *
-     * @return length in seconds
-     */
-    public float length() {
-        return (float) waveHeader.getSubChunk2Size() / waveHeader.getByteRate();
-    }
-
-    /**
-     * Timestamp of the wave length
-     *
-     * @return timestamp
-     */
-    public String timestamp() {
-        float totalSeconds = this.length();
-        float second = totalSeconds % 60;
-        int minute = (int) totalSeconds / 60 % 60;
-        int hour = (int) (totalSeconds / 3600);
-
-        StringBuilder sb = new StringBuilder();
-        if (hour > 0) {
-            sb.append(hour + ":");
-        }
-        if (minute > 0) {
-            sb.append(minute + ":");
-        }
-        sb.append(second);
-
-        return sb.toString();
-    }
-
-    /**
-     * Get the amplitudes of the wave samples (depends on the header)
-     *
-     * @return amplitudes array (signed 16-bit)
-     */
-    public short[] getSampleAmplitudes() {
-        int bytePerSample = waveHeader.getBitsPerSample() / 8;
-        int numSamples = data.length / bytePerSample;
-        short[] amplitudes = new short[numSamples];
-
-        int pointer = 0;
-        for (int i = 0; i < numSamples; i++) {
-            short amplitude = 0;
-            for (int byteNumber = 0; byteNumber < bytePerSample; byteNumber++) {
-                // little endian
-                amplitude |= (short) ((data[pointer++] & 0xFF) << (byteNumber * 8));
-            }
-            amplitudes[i] = amplitude;
-        }
-
-        return amplitudes;
-    }
-
-    public String toString() {
-        StringBuilder sb = new StringBuilder(waveHeader.toString());
-        sb.append("\n");
-        sb.append("length: " + timestamp());
-        return sb.toString();
-    }
-
-    public double[] getNormalizedAmplitudes() {
-        NormalizedSampleAmplitudes amplitudes = new NormalizedSampleAmplitudes(this);
-        return amplitudes.getNormalizedAmplitudes();
-    }
-
-    public byte[] getFingerprint() {
-        if (fingerprint == null) {
-            FingerprintManager fingerprintManager = new FingerprintManager();
-            fingerprint = fingerprintManager.extractFingerprint(this);
-        }
-        return fingerprint;
-    }
-
-    public FingerprintSimilarity getFingerprintSimilarity(Wave wave) {
-        FingerprintSimilarityComputer fingerprintSimilarityComputer =
-                        new FingerprintSimilarityComputer(this.getFingerprint(), wave.getFingerprint());
-        return fingerprintSimilarityComputer.getFingerprintsSimilarity();
-    }
-}
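For reference, the removed Wave class was the entry point of this module. A minimal usage sketch against the API above (the file path is illustrative):

    // Load a WAV file, trim half a second from each end, and inspect it.
    Wave wave = new Wave("/tmp/sample.wav");      // hypothetical path
    wave.leftTrim(0.5);
    wave.rightTrim(0.5);
    short[] amplitudes = wave.getSampleAmplitudes();
    System.out.println("duration: " + wave.length() + "s (" + wave.timestamp() + ")");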
diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/WaveFileManager.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/WaveFileManager.java
deleted file mode 100644
index bb8d1bcf9..000000000
--- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/WaveFileManager.java
+++ /dev/null
@@ -1,98 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.audio;
-
-import lombok.extern.slf4j.Slf4j;
-
-import java.io.FileOutputStream;
-import java.io.IOException;
-
-@Slf4j
-public class WaveFileManager {
-
-    private Wave wave;
-
-    public WaveFileManager() {
-        wave = new Wave();
-    }
-
-    public WaveFileManager(Wave wave) {
-        setWave(wave);
-    }
-
-    /**
-     * Save the wave file
-     *
-     * @param filename
-     *            filename to be saved
-     *
-     * @see Wave file saved
-     */
-    public void saveWaveAsFile(String filename) {
-
-        WaveHeader waveHeader = wave.getWaveHeader();
-
-        int byteRate = waveHeader.getByteRate();
-        int audioFormat = waveHeader.getAudioFormat();
-        int sampleRate = waveHeader.getSampleRate();
-        int bitsPerSample = waveHeader.getBitsPerSample();
-        int channels = waveHeader.getChannels();
-        long chunkSize = waveHeader.getChunkSize();
-        long subChunk1Size = waveHeader.getSubChunk1Size();
-        long subChunk2Size = waveHeader.getSubChunk2Size();
-        int blockAlign = waveHeader.getBlockAlign();
-
-        try {
-            FileOutputStream fos = new FileOutputStream(filename);
-            fos.write(WaveHeader.RIFF_HEADER.getBytes());
-            // little endian
-            fos.write(new byte[] {(byte) (chunkSize), (byte) (chunkSize >> 8), (byte) (chunkSize >> 16),
-                            (byte) (chunkSize >> 24)});
-            fos.write(WaveHeader.WAVE_HEADER.getBytes());
-            fos.write(WaveHeader.FMT_HEADER.getBytes());
-            fos.write(new byte[] {(byte) (subChunk1Size), (byte) (subChunk1Size >> 8), (byte) (subChunk1Size >> 16),
-                            (byte) (subChunk1Size >> 24)});
-            fos.write(new byte[] {(byte) (audioFormat), (byte) (audioFormat >> 8)});
-            fos.write(new byte[] {(byte) (channels), (byte) (channels >> 8)});
-            fos.write(new byte[] {(byte) (sampleRate), (byte) (sampleRate >> 8), (byte) (sampleRate >> 16),
-                            (byte) (sampleRate >> 24)});
-            fos.write(new byte[] {(byte) (byteRate), (byte) (byteRate >> 8), (byte) (byteRate >> 16),
-                            (byte) (byteRate >> 24)});
-            fos.write(new byte[] {(byte) (blockAlign), (byte) (blockAlign >> 8)});
-            fos.write(new byte[] {(byte) (bitsPerSample), (byte) (bitsPerSample >> 8)});
-            fos.write(WaveHeader.DATA_HEADER.getBytes());
-            fos.write(new byte[] {(byte) (subChunk2Size), (byte) (subChunk2Size >> 8), (byte) (subChunk2Size >> 16),
-                            (byte) (subChunk2Size >> 24)});
-            fos.write(wave.getBytes());
-            fos.close();
-        } catch (IOException e) {
-            log.error("",e);
-        }
-    }
-
-    public Wave getWave() {
-        return wave;
-    }
-
-    public void setWave(Wave wave) {
-        this.wave = wave;
-    }
-}
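The writer above mirrors the reader: every multi-byte field is emitted least-significant byte first. A sketch of round-tripping a file through the removed classes (paths are illustrative):

    Wave wave = new Wave("/tmp/in.wav");          // hypothetical path
    WaveFileManager manager = new WaveFileManager(wave);
    manager.saveWaveAsFile("/tmp/out.wav");       // writes the 44-byte header, then the PCM data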
diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/WaveHeader.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/WaveHeader.java
deleted file mode 100644
index fc2d09e88..000000000
--- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/WaveHeader.java
+++ /dev/null
@@ -1,281 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.audio;
-
-import lombok.extern.slf4j.Slf4j;
-
-import java.io.IOException;
-import java.io.InputStream;
-
-@Slf4j
-public class WaveHeader {
-
-    public static final String RIFF_HEADER = "RIFF";
-    public static final String WAVE_HEADER = "WAVE";
-    public static final String FMT_HEADER = "fmt ";
-    public static final String DATA_HEADER = "data";
-    public static final int HEADER_BYTE_LENGTH = 44; // 44 bytes for header
-
-    private boolean valid;
-    private String chunkId; // 4 bytes
-    private long chunkSize; // unsigned 4 bytes, little endian
-    private String format; // 4 bytes
-    private String subChunk1Id; // 4 bytes
-    private long subChunk1Size; // unsigned 4 bytes, little endian
-    private int audioFormat; // unsigned 2 bytes, little endian
-    private int channels; // unsigned 2 bytes, little endian
-    private long sampleRate; // unsigned 4 bytes, little endian
-    private long byteRate; // unsigned 4 bytes, little endian
-    private int blockAlign; // unsigned 2 bytes, little endian
-    private int bitsPerSample; // unsigned 2 bytes, little endian
-    private String subChunk2Id; // 4 bytes
-    private long subChunk2Size; // unsigned 4 bytes, little endian
-
-    public WaveHeader() {
-        // init an 8kHz 16-bit mono wav
-        chunkSize = 36;
-        subChunk1Size = 16;
-        audioFormat = 1;
-        channels = 1;
-        sampleRate = 8000;
-        byteRate = 16000;
-        blockAlign = 2;
-        bitsPerSample = 16;
-        subChunk2Size = 0;
-        valid = true;
-    }
-
-    public WaveHeader(InputStream inputStream) {
-        valid = loadHeader(inputStream);
-    }
-
-    private boolean loadHeader(InputStream inputStream) {
-
-        byte[] headerBuffer = new byte[HEADER_BYTE_LENGTH];
-        try {
-            inputStream.read(headerBuffer);
-
-            // read header
-            int pointer = 0;
-            chunkId = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], headerBuffer[pointer++],
-                            headerBuffer[pointer++]});
-            // little endian
-            chunkSize = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
-                            | (long) (headerBuffer[pointer++] & 0xff) << 16
-                            | (long) (headerBuffer[pointer++] & 0xff << 24);
-            format = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++], headerBuffer[pointer++],
-                            headerBuffer[pointer++]});
-            subChunk1Id = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++],
-                            headerBuffer[pointer++], headerBuffer[pointer++]});
-            subChunk1Size = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
-                            | (long) (headerBuffer[pointer++] & 0xff) << 16
-                            | (long) (headerBuffer[pointer++] & 0xff) << 24;
-            audioFormat = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8);
-            channels = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8);
-            sampleRate = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
-                            | (long) (headerBuffer[pointer++] & 0xff) << 16
-                            | (long) (headerBuffer[pointer++] & 0xff) << 24;
-            byteRate = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
-                            | (long) (headerBuffer[pointer++] & 0xff) << 16
-                            | (long) (headerBuffer[pointer++] & 0xff) << 24;
-            blockAlign = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8);
-            bitsPerSample = (int) ((headerBuffer[pointer++] & 0xff) | (headerBuffer[pointer++] & 0xff) << 8);
-            subChunk2Id = new String(new byte[] {headerBuffer[pointer++], headerBuffer[pointer++],
-                            headerBuffer[pointer++], headerBuffer[pointer++]});
-            subChunk2Size = (long) (headerBuffer[pointer++] & 0xff) | (long) (headerBuffer[pointer++] & 0xff) << 8
-                            | (long) (headerBuffer[pointer++] & 0xff) << 16
-                            | (long) (headerBuffer[pointer++] & 0xff) << 24;
-            // end read header
-
-            // the inputStream should be closed outside this method
-
-            // dis.close();
-
-        } catch (IOException e) {
-            log.error("",e);
-            return false;
-        }
-
-        if (bitsPerSample != 8 && bitsPerSample != 16) {
-            System.err.println("WaveHeader: only supports bitsPerSample 8 or 16");
-            return false;
-        }
-
-        // check that the format is supported
-        if (chunkId.toUpperCase().equals(RIFF_HEADER) && format.toUpperCase().equals(WAVE_HEADER) && audioFormat == 1) {
-            return true;
-        } else {
-            System.err.println("WaveHeader: Unsupported header format");
-        }
-
-        return false;
-    }
-
-    public boolean isValid() {
-        return valid;
-    }
-
-    public String getChunkId() {
-        return chunkId;
-    }
-
-    public long getChunkSize() {
-        return chunkSize;
-    }
-
-    public String getFormat() {
-        return format;
-    }
-
-    public String getSubChunk1Id() {
-        return subChunk1Id;
-    }
-
-    public long getSubChunk1Size() {
-        return subChunk1Size;
-    }
-
-    public int getAudioFormat() {
-        return audioFormat;
-    }
-
-    public int getChannels() {
-        return channels;
-    }
-
-    public int getSampleRate() {
-        return (int) sampleRate;
-    }
-
-    public int getByteRate() {
-        return (int) byteRate;
-    }
-
-    public int getBlockAlign() {
-        return blockAlign;
-    }
-
-    public int getBitsPerSample() {
-        return bitsPerSample;
-    }
-
-    public String getSubChunk2Id() {
-        return subChunk2Id;
-    }
-
-    public long getSubChunk2Size() {
-        return subChunk2Size;
-    }
-
-    public void setSampleRate(int sampleRate) {
-        int newSubChunk2Size = (int) (this.subChunk2Size * sampleRate / this.sampleRate);
-        // if the number of bytes per sample is even, newSubChunk2Size also needs to be even
-        if ((bitsPerSample / 8) % 2 == 0) {
-            if (newSubChunk2Size % 2 != 0) {
-                newSubChunk2Size++;
-            }
-        }
-
-        this.sampleRate = sampleRate;
-        this.byteRate = sampleRate * bitsPerSample / 8;
-        this.chunkSize = newSubChunk2Size + 36;
-        this.subChunk2Size = newSubChunk2Size;
-    }
-
-    public void setChunkId(String chunkId) {
-        this.chunkId = chunkId;
-    }
-
-    public void setChunkSize(long chunkSize) {
-        this.chunkSize = chunkSize;
-    }
-
-    public void setFormat(String format) {
-        this.format = format;
-    }
-
-    public void setSubChunk1Id(String subChunk1Id) {
-        this.subChunk1Id = subChunk1Id;
-    }
-
-    public void setSubChunk1Size(long subChunk1Size) {
-        this.subChunk1Size = subChunk1Size;
-    }
-
-    public void setAudioFormat(int audioFormat) {
-        this.audioFormat = audioFormat;
-    }
-
-    public void setChannels(int channels) {
-        this.channels = channels;
-    }
-
-    public void setByteRate(long byteRate) {
-        this.byteRate = byteRate;
-    }
-
-    public void setBlockAlign(int blockAlign) {
-        this.blockAlign = blockAlign;
-    }
-
-    public void setBitsPerSample(int bitsPerSample) {
-        this.bitsPerSample = bitsPerSample;
-    }
-
-    public void setSubChunk2Id(String subChunk2Id) {
-        this.subChunk2Id = subChunk2Id;
-    }
-
-    public void setSubChunk2Size(long subChunk2Size) {
-        this.subChunk2Size = subChunk2Size;
-    }
-
-    public String toString() {
-
-        StringBuilder sb = new StringBuilder();
-        sb.append("chunkId: " + chunkId);
-        sb.append("\n");
-        sb.append("chunkSize: " + chunkSize);
-        sb.append("\n");
-        sb.append("format: " + format);
-        sb.append("\n");
-        sb.append("subChunk1Id: " + subChunk1Id);
-        sb.append("\n");
-        sb.append("subChunk1Size: " + subChunk1Size);
-        sb.append("\n");
-        sb.append("audioFormat: " + audioFormat);
-        sb.append("\n");
-        sb.append("channels: " + channels);
-        sb.append("\n");
-        sb.append("sampleRate: " + sampleRate);
-        sb.append("\n");
-        sb.append("byteRate: " + byteRate);
-        sb.append("\n");
-        sb.append("blockAlign: " + blockAlign);
-        sb.append("\n");
-        sb.append("bitsPerSample: " + bitsPerSample);
-        sb.append("\n");
-        sb.append("subChunk2Id: " + subChunk2Id);
-        sb.append("\n");
-        sb.append("subChunk2Size: " + subChunk2Size);
-        return sb.toString();
-    }
-}
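The shift-and-OR parsing in loadHeader() is equivalent to reading unsigned little-endian integers. A compact restatement with java.nio, assuming the canonical 44-byte WAV layout (offsets are inferred from the field order above, not stated in the source):

    import java.nio.ByteBuffer;
    import java.nio.ByteOrder;

    // Read an unsigned 32-bit little-endian field from the header buffer.
    static long readUint32Le(byte[] header, int offset) {
        return ByteBuffer.wrap(header, offset, 4).order(ByteOrder.LITTLE_ENDIAN).getInt() & 0xFFFFFFFFL;
    }
    // e.g. chunkSize sits at offset 4, sampleRate at offset 24, subChunk2Size at offset 40.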
diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/dsp/FastFourierTransform.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/dsp/FastFourierTransform.java
deleted file mode 100644
index bc0bf3279..000000000
--- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/dsp/FastFourierTransform.java
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.audio.dsp;
-
-import org.jtransforms.fft.DoubleFFT_1D;
-
-public class FastFourierTransform {
-
-    /**
-     * Get the frequency intensities
-     *
-     * @param amplitudes amplitudes of the signal. Format depends on value of complex
-     * @param complex    if true, amplitudes is assumed to be complex interlaced (re = even, im = odd); if false,
-     *                   amplitudes are assumed to be real valued.
-     * @return intensities of each frequency unit: mag[frequency_unit]=intensity
-     */
-    public double[] getMagnitudes(double[] amplitudes, boolean complex) {
-
-        final int sampleSize = amplitudes.length;
-        final int nrofFrequencyBins = sampleSize / 2;
-
-
-        // call the fft and transform the complex numbers
-        if (complex) {
-            DoubleFFT_1D fft = new DoubleFFT_1D(nrofFrequencyBins);
-            fft.complexForward(amplitudes);
-        } else {
-            DoubleFFT_1D fft = new DoubleFFT_1D(sampleSize);
-            fft.realForward(amplitudes);
-            // amplitudes[1] contains re[sampleSize/2] or im[(sampleSize-1) / 2] (depending on whether sampleSize is odd or even)
-            // Discard it as it is useless without the other part
-            // im part of dc bin is always 0 for real input
-            amplitudes[1] = 0;
-        }
-        // end call the fft and transform the complex numbers
-
-        // even indexes (0,2,4,6,...) are real parts
-        // odd indexes (1,3,5,7,...) are imaginary parts
-        double[] mag = new double[nrofFrequencyBins];
-        for (int i = 0; i < nrofFrequencyBins; i++) {
-            final int f = 2 * i;
-            mag[i] = Math.sqrt(amplitudes[f] * amplitudes[f] + amplitudes[f + 1] * amplitudes[f + 1]);
-        }
-
-        return mag;
-    }
-
-    /**
-     * Get the frequency intensities. Backwards compatible with previous versions w.r.t. the number of frequency bins.
-     * Use getMagnitudes(amplitudes, true) to get all bins.
-     *
-     * @param amplitudes complex-valued signal to transform. Even indexes are real and odd indexes are imaginary
-     * @return intensities of each frequency unit: mag[frequency_unit]=intensity
-     */
-    public double[] getMagnitudes(double[] amplitudes) {
-        double[] magnitudes = getMagnitudes(amplitudes, true);
-
-        double[] halfOfMagnitudes = new double[magnitudes.length / 2];
-        System.arraycopy(magnitudes, 0, halfOfMagnitudes, 0, halfOfMagnitudes.length);
-        return halfOfMagnitudes;
-    }
-
-}
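A sketch of the magnitude computation on a synthetic real-valued frame (note that realForward transforms the input array in place):

    double[] frame = new double[1024];
    for (int i = 0; i < frame.length; i++) {
        frame[i] = Math.sin(2 * Math.PI * 50 * i / 1024.0); // 50 cycles per frame
    }
    double[] mag = new FastFourierTransform().getMagnitudes(frame, false);
    // mag.length == 512; the energy peaks in bin 50 for this input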
diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/dsp/LinearInterpolation.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/dsp/LinearInterpolation.java
deleted file mode 100644
index 1595c8974..000000000
--- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/dsp/LinearInterpolation.java
+++ /dev/null
@@ -1,66 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.audio.dsp;
-
-public class LinearInterpolation {
-
-    public LinearInterpolation() {
-
-    }
-
-    /**
-     * Do interpolation on the samples according to the original and destination sample rates
-     *
-     * @param oldSampleRate sample rate of the original samples
-     * @param newSampleRate sample rate of the interpolated samples
-     * @param samples       original samples
-     * @return interpolated samples
-     */
-    public short[] interpolate(int oldSampleRate, int newSampleRate, short[] samples) {
-
-        if (oldSampleRate == newSampleRate) {
-            return samples;
-        }
-
-        int newLength = Math.round(((float) samples.length / oldSampleRate * newSampleRate));
-        float lengthMultiplier = (float) newLength / samples.length;
-        short[] interpolatedSamples = new short[newLength];
-
-        // interpolate the value by the linear equation y=mx+c
-        for (int i = 0; i < newLength; i++) {
-
-            // get the nearest positions for the interpolated point
-            float currentPosition = i / lengthMultiplier;
-            int nearestLeftPosition = (int) currentPosition;
-            int nearestRightPosition = nearestLeftPosition + 1;
-            if (nearestRightPosition >= samples.length) {
-                nearestRightPosition = samples.length - 1;
-            }
-
-            float slope = samples[nearestRightPosition] - samples[nearestLeftPosition]; // delta x is 1
-            float positionFromLeft = currentPosition - nearestLeftPosition;
-
-            interpolatedSamples[i] = (short) (slope * positionFromLeft + samples[nearestLeftPosition]); // y=mx+c
-        }
-
-        return interpolatedSamples;
-    }
-}
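A sketch of the resulting lengths: for one second of 44.1 kHz audio downsampled to 8 kHz, newLength = round(44100f / 44100 * 8000) = 8000:

    short[] source = new short[44100];            // one second of samples, contents elided
    short[] resampled = new LinearInterpolation().interpolate(44100, 8000, source);
    // resampled.length == 8000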
diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/dsp/Resampler.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/dsp/Resampler.java
deleted file mode 100644
index c50e9a385..000000000
--- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/dsp/Resampler.java
+++ /dev/null
@@ -1,84 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.audio.dsp;
-
-public class Resampler {
-
-    public Resampler() {}
-
-    /**
-     * Do resampling. Currently each amplitude is stored as a short, so the maximum bitsPerSample is 16 (bytePerSample is 2)
-     *
-     * @param sourceData    The source data in bytes
-     * @param bitsPerSample How many bits represent one sample (currently supports max. bitsPerSample=16)
-     * @param sourceRate    Sample rate of the source data
-     * @param targetRate    Sample rate of the target data
-     * @return re-sampled data
-     */
-    public byte[] reSample(byte[] sourceData, int bitsPerSample, int sourceRate, int targetRate) {
-
-        // make the bytes to amplitudes first
-        int bytePerSample = bitsPerSample / 8;
-        int numSamples = sourceData.length / bytePerSample;
-        short[] amplitudes = new short[numSamples]; // 16 bit, use a short to store
-
-        int pointer = 0;
-        for (int i = 0; i < numSamples; i++) {
-            short amplitude = 0;
-            for (int byteNumber = 0; byteNumber < bytePerSample; byteNumber++) {
-                // little endian
-                amplitude |= (short) ((sourceData[pointer++] & 0xFF) << (byteNumber * 8));
-            }
-            amplitudes[i] = amplitude;
-        }
-        // end make the amplitudes
-
-        // do interpolation
-        LinearInterpolation reSample = new LinearInterpolation();
-        short[] targetSample = reSample.interpolate(sourceRate, targetRate, amplitudes);
-        int targetLength = targetSample.length;
-        // end do interpolation
-
-        // TODO: Remove the high frequency signals with a digital filter, leaving a signal containing only half-sample-rated frequency information, but still sampled at a rate of target sample rate. Usually FIR is used
-
-        // end resample the amplitudes
-
-        // convert the amplitude to bytes
-        byte[] bytes;
-        if (bytePerSample == 1) {
-            bytes = new byte[targetLength];
-            for (int i = 0; i < targetLength; i++) {
-                bytes[i] = (byte) targetSample[i];
-            }
-        } else {
-            // suppose bytePerSample==2
-            bytes = new byte[targetLength * 2];
-            for (int i = 0; i < targetSample.length; i++) {
-                // little endian
-                bytes[i * 2] = (byte) (targetSample[i] & 0xff);
-                bytes[i * 2 + 1] = (byte) ((targetSample[i] >> 8) & 0xff);
-            }
-        }
-        // end convert the amplitude to bytes
-
-        return bytes;
-    }
-}
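Byte-level view of the same operation: 16-bit samples are reassembled from little-endian byte pairs, interpolated, and serialized back, so the output byte count is twice the target sample count:

    byte[] source = new byte[44100 * 2];          // one second of 16-bit mono PCM, contents elided
    byte[] target = new Resampler().reSample(source, 16, 44100, 8000);
    // target.length == 16000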
diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/dsp/WindowFunction.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/dsp/WindowFunction.java
deleted file mode 100644
index b3783df81..000000000
--- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/dsp/WindowFunction.java
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.audio.dsp;
-
-public class WindowFunction {
-
-    public static final int RECTANGULAR = 0;
-    public static final int BARTLETT = 1;
-    public static final int HANNING = 2;
-    public static final int HAMMING = 3;
-    public static final int BLACKMAN = 4;
-
-    int windowType = 0; // defaults to rectangular window
-
-    public WindowFunction() {}
-
-    public void setWindowType(int wt) {
-        windowType = wt;
-    }
-
-    public void setWindowType(String w) {
-        if (w.toUpperCase().equals("RECTANGULAR"))
-            windowType = RECTANGULAR;
-        if (w.toUpperCase().equals("BARTLETT"))
-            windowType = BARTLETT;
-        if (w.toUpperCase().equals("HANNING"))
-            windowType = HANNING;
-        if (w.toUpperCase().equals("HAMMING"))
-            windowType = HAMMING;
-        if (w.toUpperCase().equals("BLACKMAN"))
-            windowType = BLACKMAN;
-    }
-
-    public int getWindowType() {
-        return windowType;
-    }
-
-    /**
-     * Generate a window
-     *
-     * @param nSamples size of the window
-     * @return window in array
-     */
-    public double[] generate(int nSamples) {
-        // generate nSamples window function values
-        // for index values 0 .. nSamples - 1
-        int m = nSamples / 2;
-        double r;
-        double pi = Math.PI;
-        double[] w = new double[nSamples];
-        switch (windowType) {
-            case BARTLETT: // Bartlett (triangular) window
-                for (int n = 0; n < nSamples; n++)
-                    w[n] = 1.0f - Math.abs(n - m) / m;
-                break;
-            case HANNING: // Hanning window
-                r = pi / (m + 1);
-                for (int n = -m; n < m; n++)
-                    w[m + n] = 0.5f + 0.5f * Math.cos(n * r);
-                break;
-            case HAMMING: // Hamming window
-                r = pi / m;
-                for (int n = -m; n < m; n++)
-                    w[m + n] = 0.54f + 0.46f * Math.cos(n * r);
-                break;
-            case BLACKMAN: // Blackman window
-                r = pi / m;
-                for (int n = -m; n < m; n++)
-                    w[m + n] = 0.42f + 0.5f * Math.cos(n * r) + 0.08f * Math.cos(2 * n * r);
-                break;
-            default: // Rectangular window function
-                for (int n = 0; n < nSamples; n++)
-                    w[n] = 1.0f;
-        }
-        return w;
-    }
-}
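A sketch of generating and applying a window, as buildSpectrogram() in the Spectrogram class below does:

    WindowFunction window = new WindowFunction();
    window.setWindowType("Hamming");
    double[] win = window.generate(1024);
    double[] frame = new double[1024];            // one frame of samples, contents elided
    for (int n = 0; n < frame.length; n++) {
        frame[n] *= win[n];
    }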
diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/dsp/package-info.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/dsp/package-info.java
deleted file mode 100644
index cf19c7831..000000000
--- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/dsp/package-info.java
+++ /dev/null
@@ -1,21 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.audio.dsp;
diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/extension/NormalizedSampleAmplitudes.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/extension/NormalizedSampleAmplitudes.java
deleted file mode 100644
index cc1e7028f..000000000
--- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/extension/NormalizedSampleAmplitudes.java
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.audio.extension;
-
-
-import org.datavec.audio.Wave;
-
-public class NormalizedSampleAmplitudes {
-
-    private Wave wave;
-    private double[] normalizedAmplitudes; // normalizedAmplitudes[sampleNumber]=normalizedAmplitudeInTheFrame
-
-    public NormalizedSampleAmplitudes(Wave wave) {
-        this.wave = wave;
-    }
-
-    /**
-     *
-     * Get normalized amplitude of each frame
-     *
-     * @return array of normalized amplitudes (signed 16-bit): normalizedAmplitudes[frame]=amplitude
-     */
-    public double[] getNormalizedAmplitudes() {
-
-        if (normalizedAmplitudes == null) {
-
-            boolean signed = true;
-
-            // usually 8bit is unsigned
-            if (wave.getWaveHeader().getBitsPerSample() == 8) {
-                signed = false;
-            }
-
-            short[] amplitudes = wave.getSampleAmplitudes();
-            int numSamples = amplitudes.length;
-            int maxAmplitude = 1 << (wave.getWaveHeader().getBitsPerSample() - 1);
-
-            if (!signed) { // one more bit for unsigned value
-                maxAmplitude <<= 1;
-            }
-
-            normalizedAmplitudes = new double[numSamples];
-            for (int i = 0; i < numSamples; i++) {
-                normalizedAmplitudes[i] = (double) amplitudes[i] / maxAmplitude;
-            }
-        }
-        return normalizedAmplitudes;
-    }
-}
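The normalization divides each signed 16-bit sample by 2^15, mapping it into roughly -1..1 (e.g. 16384 / 32768 = 0.5). A sketch of the public entry point (path illustrative):

    Wave wave = new Wave("/tmp/sample.wav");      // hypothetical path
    double[] normalized = wave.getNormalizedAmplitudes();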
diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/extension/Spectrogram.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/extension/Spectrogram.java
deleted file mode 100644
index 6ca5d84f3..000000000
--- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/extension/Spectrogram.java
+++ /dev/null
@@ -1,214 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.audio.extension;
-
-
-import org.datavec.audio.Wave;
-import org.datavec.audio.dsp.FastFourierTransform;
-import org.datavec.audio.dsp.WindowFunction;
-
-public class Spectrogram {
-
-    public static final int SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE = 1024;
-    public static final int SPECTROGRAM_DEFAULT_OVERLAP_FACTOR = 0; // 0 for no overlapping
-
-    private Wave wave;
-    private double[][] spectrogram; // relative spectrogram
-    private double[][] absoluteSpectrogram; // absolute spectrogram
-    private int fftSampleSize; // number of samples in the FFT; the value must be a power of 2
-    private int overlapFactor; // 1/overlapFactor overlap, e.g. 1/4 = 25% overlap
-    private int numFrames; // number of frames of the spectrogram
-    private int framesPerSecond; // frames per second of the spectrogram
-    private int numFrequencyUnit; // number of y-axis units
-    private double unitFrequency; // frequency per y-axis unit
-
-    /**
-     * Constructor
-     *
-     * @param wave
-     */
-    public Spectrogram(Wave wave) {
-        this.wave = wave;
-        // default
-        this.fftSampleSize = SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE;
-        this.overlapFactor = SPECTROGRAM_DEFAULT_OVERLAP_FACTOR;
-        buildSpectrogram();
-    }
-
-    /**
-     * Constructor
-     *
-     * @param wave
-     * @param fftSampleSize number of samples in the FFT; the value must be a power of 2
-     * @param overlapFactor 1/overlapFactor overlap, e.g. 1/4 = 25% overlap, 0 for no overlap
-     */
-    public Spectrogram(Wave wave, int fftSampleSize, int overlapFactor) {
-        this.wave = wave;
-
-        if (Integer.bitCount(fftSampleSize) == 1) {
-            this.fftSampleSize = fftSampleSize;
-        } else {
-            System.err.print("The input number must be a power of 2");
-            this.fftSampleSize = SPECTROGRAM_DEFAULT_FFT_SAMPLE_SIZE;
-        }
-
-        this.overlapFactor = overlapFactor;
-
-        buildSpectrogram();
-    }
-
-    /**
-     * Build spectrogram
-     */
-    private void buildSpectrogram() {
-
-        short[] amplitudes = wave.getSampleAmplitudes();
-        int numSamples = amplitudes.length;
-
-        int pointer = 0;
-        // overlapping
-        if (overlapFactor > 1) {
-            int numOverlappedSamples = numSamples * overlapFactor;
-            int backSamples = fftSampleSize * (overlapFactor - 1) / overlapFactor;
-            short[] overlapAmp = new short[numOverlappedSamples];
-            pointer = 0;
-            for (int i = 0; i < amplitudes.length; i++) {
-                overlapAmp[pointer++] = amplitudes[i];
-                if (pointer % fftSampleSize == 0) {
-                    // overlap
-                    i -= backSamples;
-                }
-            }
-            numSamples = numOverlappedSamples;
-            amplitudes = overlapAmp;
-        }
-        // end overlapping
-
-        numFrames = numSamples / fftSampleSize;
-        framesPerSecond = (int) (numFrames / wave.length());
-
-        // set signals for fft
-        WindowFunction window = new WindowFunction();
-        window.setWindowType("Hamming");
-        double[] win = window.generate(fftSampleSize);
-
-        double[][] signals = new double[numFrames][];
-        for (int f = 0; f < numFrames; f++) {
-            signals[f] = new double[fftSampleSize];
-            int startSample = f * fftSampleSize;
-            for (int n = 0; n < fftSampleSize; n++) {
-                signals[f][n] = amplitudes[startSample + n] * win[n];
-            }
-        }
-        // end set signals for fft
-
-        absoluteSpectrogram = new double[numFrames][];
-        // for each frame in signals, do fft on it
-        FastFourierTransform fft = new FastFourierTransform();
-        for (int i = 0; i < numFrames; i++) {
-            absoluteSpectrogram[i] = fft.getMagnitudes(signals[i], false);
-        }
-
-        if (absoluteSpectrogram.length > 0) {
-
-            numFrequencyUnit = absoluteSpectrogram[0].length;
-            unitFrequency = (double) wave.getWaveHeader().getSampleRate() / 2 / numFrequencyUnit; // frequencies can only be resolved up to half the sample rate (Nyquist)
-
-            // normalization of absoluteSpectrogram
-            spectrogram = new double[numFrames][numFrequencyUnit];
-
-            // set max and min amplitudes
-            double maxAmp = Double.MIN_VALUE;
-            double minAmp = Double.MAX_VALUE;
-            for (int i = 0; i < numFrames; i++) {
-                for (int j = 0; j < numFrequencyUnit; j++) {
-                    if (absoluteSpectrogram[i][j] > maxAmp) {
-                        maxAmp = absoluteSpectrogram[i][j];
-                    } else if (absoluteSpectrogram[i][j] < minAmp) {
-                        minAmp = absoluteSpectrogram[i][j];
-                    }
-                }
-            }
-            // end set max and min amplitudes
-
-            // normalization
-            // avoid division by zero
-            double minValidAmp = 0.00000000001F;
-            if (minAmp == 0) {
-                minAmp = minValidAmp;
-            }
-
-            double diff = Math.log10(maxAmp / minAmp); // perceptual difference
-            for (int i = 0; i < numFrames; i++) {
-                for (int j = 0; j < numFrequencyUnit; j++) {
-                    if (absoluteSpectrogram[i][j] < minValidAmp) {
-                        spectrogram[i][j] = 0;
-                    } else {
-                        spectrogram[i][j] = (Math.log10(absoluteSpectrogram[i][j] / minAmp)) / diff;
-                    }
-                }
-            }
-            // end normalization
-        }
-    }
-
-    /**
-     * Get spectrogram: spectrogram[time][frequency]=intensity
-     *
-     * @return logarithm normalized spectrogram
-     */
-    public double[][] getNormalizedSpectrogramData() {
-        return spectrogram;
-    }
-
-    /**
-     * Get spectrogram: spectrogram[time][frequency]=intensity
-     *
-     * @return absolute spectrogram
-     */
-    public double[][] getAbsoluteSpectrogramData() {
-        return absoluteSpectrogram;
-    }
-
-    public int getNumFrames() {
-        return numFrames;
-    }
-
-    public int getFramesPerSecond() {
-        return framesPerSecond;
-    }
-
-    public int getNumFrequencyUnit() {
-        return numFrequencyUnit;
-    }
-
-    public double getUnitFrequency() {
-        return unitFrequency;
-    }
-
-    public int getFftSampleSize() {
-        return fftSampleSize;
-    }
-
-    public int getOverlapFactor() {
-        return overlapFactor;
-    }
-}
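A usage sketch (path illustrative): 512-sample FFT frames with overlapFactor 2, i.e. 50% overlap:

    Wave wave = new Wave("/tmp/sample.wav");      // hypothetical path
    Spectrogram spectrogram = wave.getSpectrogram(512, 2);
    double[][] data = spectrogram.getNormalizedSpectrogramData(); // [time][frequency], values in 0..1
    System.out.println(spectrogram.getNumFrames() + " frames, "
                    + spectrogram.getNumFrequencyUnit() + " bins of "
                    + spectrogram.getUnitFrequency() + " Hz each");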
diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/FingerprintManager.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/FingerprintManager.java
deleted file mode 100644
index d6b150348..000000000
--- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/FingerprintManager.java
+++ /dev/null
@@ -1,272 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.audio.fingerprint;
-
-
-import lombok.extern.slf4j.Slf4j;
-import org.datavec.audio.Wave;
-import org.datavec.audio.WaveHeader;
-import org.datavec.audio.dsp.Resampler;
-import org.datavec.audio.extension.Spectrogram;
-import org.datavec.audio.processor.TopManyPointsProcessorChain;
-import org.datavec.audio.properties.FingerprintProperties;
-
-import java.io.FileInputStream;
-import java.io.FileOutputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.util.Iterator;
-import java.util.LinkedList;
-import java.util.List;
-
-@Slf4j
-public class FingerprintManager {
-
-    private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
-    private int sampleSizePerFrame = fingerprintProperties.getSampleSizePerFrame();
-    private int overlapFactor = fingerprintProperties.getOverlapFactor();
-    private int numRobustPointsPerFrame = fingerprintProperties.getNumRobustPointsPerFrame();
-    private int numFilterBanks = fingerprintProperties.getNumFilterBanks();
-
-    /**
-     * Constructor
-     */
-    public FingerprintManager() {
-
-    }
-
-    /**
-     * Extract fingerprint from Wave object
-     *
-     * @param wave Wave object from which to extract the fingerprint
-     * @return fingerprint in bytes
-     */
-    public byte[] extractFingerprint(Wave wave) {
-
-        int[][] coordinates; // coordinates[x][0..3]=y0..y3
-        byte[] fingerprint = new byte[0];
-
-        // resample to target rate
-        Resampler resampler = new Resampler();
-        int sourceRate = wave.getWaveHeader().getSampleRate();
-        int targetRate = fingerprintProperties.getSampleRate();
-
-        byte[] resampledWaveData = resampler.reSample(wave.getBytes(), wave.getWaveHeader().getBitsPerSample(),
-                        sourceRate, targetRate);
-
-        // update the wave header
-        WaveHeader resampledWaveHeader = wave.getWaveHeader();
-        resampledWaveHeader.setSampleRate(targetRate);
-
-        // make resampled wave
-        Wave resampledWave = new Wave(resampledWaveHeader, resampledWaveData);
-        // end resample to target rate
-
-        // get spectrogram's data
-        Spectrogram spectrogram = resampledWave.getSpectrogram(sampleSizePerFrame, overlapFactor);
-        double[][] spectorgramData = spectrogram.getNormalizedSpectrogramData();
-
-        List<Integer>[] pointsLists = getRobustPointList(spectorgramData);
-        int numFrames = pointsLists.length;
-
-        // prepare fingerprint bytes
-        coordinates = new int[numFrames][numRobustPointsPerFrame];
-
-        for (int x = 0; x < numFrames; x++) {
-            if (pointsLists[x].size() == numRobustPointsPerFrame) {
-                Iterator<Integer> pointsListsIterator = pointsLists[x].iterator();
-                for (int y = 0; y < numRobustPointsPerFrame; y++) {
-                    coordinates[x][y] = pointsListsIterator.next();
-                }
-            } else {
-                // use -1 to fill the empty byte
-                for (int y = 0; y < numRobustPointsPerFrame; y++) {
-                    coordinates[x][y] = -1;
-                }
-            }
-        }
-        // end make fingerprint
-
-        // for each valid coordinate, append with its intensity
-        List<Byte> byteList = new LinkedList<>();
-        for (int i = 0; i < numFrames; i++) {
-            for (int j = 0; j < numRobustPointsPerFrame; j++) {
-                if (coordinates[i][j] != -1) {
-                    // first 2 bytes is x
-                    byteList.add((byte) (i >> 8));
-                    byteList.add((byte) i);
-
-                    // next 2 bytes is y
-                    int y = coordinates[i][j];
-                    byteList.add((byte) (y >> 8));
-                    byteList.add((byte) y);
-
-                    // next 4 bytes is intensity
-                    int intensity = (int) (spectorgramData[i][y] * Integer.MAX_VALUE); // spectorgramData ranges from 0~1
-                    byteList.add((byte) (intensity >> 24));
-                    byteList.add((byte) (intensity >> 16));
-                    byteList.add((byte) (intensity >> 8));
-                    byteList.add((byte) intensity);
-                }
-            }
-        }
-        // end for each valid coordinate, append with its intensity
-
-        fingerprint = new byte[byteList.size()];
-        Iterator<Byte> byteListIterator = byteList.iterator();
-        int pointer = 0;
-        while (byteListIterator.hasNext()) {
-            fingerprint[pointer++] = byteListIterator.next();
-        }
-
-        return fingerprint;
-    }
-
-    /**
-     * Get bytes from fingerprint file
-     *
-     * @param fingerprintFile fingerprint filename
-     * @return fingerprint in bytes
-     */
-    public byte[] getFingerprintFromFile(String fingerprintFile) {
-        byte[] fingerprint = null;
-        try {
-            InputStream fis = new FileInputStream(fingerprintFile);
-            fingerprint = getFingerprintFromInputStream(fis);
-            fis.close();
-        } catch (IOException e) {
-            log.error("",e);
-        }
-        return fingerprint;
-    }
-
-    /**
-     * Get bytes from fingerprint inputstream
-     *
-     * @param inputStream fingerprint inputstream
-     * @return fingerprint in bytes
-     */
-    public byte[] getFingerprintFromInputStream(InputStream inputStream) {
-        byte[] fingerprint = null;
-        try {
-            fingerprint = new byte[inputStream.available()];
-            inputStream.read(fingerprint);
-        } catch (IOException e) {
-            log.error("",e);
-        }
-        return fingerprint;
-    }
-
-    /**
-     * Save fingerprint to a file
-     *
-     * @param fingerprint fingerprint bytes
-     * @param filename    fingerprint filename
-     * @see FingerprintManager file saved
-     */
-    public void saveFingerprintAsFile(byte[] fingerprint, String filename) {
-
-        FileOutputStream fileOutputStream;
-        try {
-            fileOutputStream = new FileOutputStream(filename);
-            fileOutputStream.write(fingerprint);
-            fileOutputStream.close();
-        } catch (IOException e) {
-            log.error("",e);
-        }
-    }
-
-    // robustLists[x]=y1,y2,y3,...
-    private List<Integer>[] getRobustPointList(double[][] spectrogramData) {
-
-        int numX = spectrogramData.length;
-        int numY = spectrogramData[0].length;
-
-        double[][] allBanksIntensities = new double[numX][numY];
-        int bandwidthPerBank = numY / numFilterBanks;
-
-        for (int b = 0; b < numFilterBanks; b++) {
-
-            double[][] bankIntensities = new double[numX][bandwidthPerBank];
-
-            for (int i = 0; i < numX; i++) {
-                System.arraycopy(spectrogramData[i], b * bandwidthPerBank, bankIntensities[i], 0, bandwidthPerBank);
-            }
-
-            // get the most robust point in each filter bank
-            TopManyPointsProcessorChain processorChain = new TopManyPointsProcessorChain(bankIntensities, 1);
-            double[][] processedIntensities = processorChain.getIntensities();
-
-            for (int i = 0; i < numX; i++) {
-                System.arraycopy(processedIntensities[i], 0, allBanksIntensities[i], b * bandwidthPerBank,
-                                bandwidthPerBank);
-            }
-        }
-
-        List<int[]> robustPointList = new LinkedList<>();
-
-        // find robust points
-        for (int i = 0; i < allBanksIntensities.length; i++) {
-            for (int j = 0; j < allBanksIntensities[i].length; j++) {
-                if (allBanksIntensities[i][j] > 0) {
-
-                    int[] point = new int[] {i, j};
-                    //System.out.println(i+","+frequency);
-                    robustPointList.add(point);
-                }
-            }
-        }
-        // end find robust points
-
-        List<Integer>[] robustLists = new LinkedList[spectrogramData.length];
-        for (int i = 0; i < robustLists.length; i++) {
-            robustLists[i] = new LinkedList<>();
-        }
-
-        // robustLists[x]=y1,y2,y3,...
-        for (int[] coor : robustPointList) {
-            robustLists[coor[0]].add(coor[1]);
-        }
-
-        // return the list per frame
-        return robustLists;
-    }
-
-    /**
-     * Number of frames in a fingerprint.
-     * Each frame is 8 bytes long.
-     * Usually there is more than one point in each frame, so we cannot simply divide the byte length by 8.
-     * The last 8 bytes of the fingerprint are the last frame of this wave;
-     * the first 2 bytes of those 8 are the x position, i.e. (number_of_frames - 1), of this wave.
-     *
-     * @param fingerprint fingerprint bytes
-     * @return number of frames of the fingerprint
-     */
-    public static int getNumFrames(byte[] fingerprint) {
-
-        if (fingerprint.length < 8) {
-            return 0;
-        }
-
-        // get the last x-coordinate ((length-8) & (length-7)th bytes) from the fingerprint
-        return ((fingerprint[fingerprint.length - 8] & 0xff) << 8 | (fingerprint[fingerprint.length - 7] & 0xff)) + 1;
-    }
-}
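A sketch of extracting and persisting a fingerprint (paths illustrative):

    Wave wave = new Wave("/tmp/sample.wav");      // hypothetical path
    FingerprintManager manager = new FingerprintManager();
    byte[] fingerprint = manager.extractFingerprint(wave);
    manager.saveFingerprintAsFile(fingerprint, "/tmp/sample.fingerprint");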
diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/FingerprintSimilarity.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/FingerprintSimilarity.java
deleted file mode 100644
index c98fc0a01..000000000
--- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/FingerprintSimilarity.java
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.audio.fingerprint;
-
-
-import org.datavec.audio.properties.FingerprintProperties;
-
-public class FingerprintSimilarity {
-
-    private FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance();
-    private int mostSimilarFramePosition;
-    private float score;
-    private float similarity;
-
-    /**
-     * Constructor
-     */
-    public FingerprintSimilarity() {
-        mostSimilarFramePosition = Integer.MIN_VALUE;
-        score = -1;
-        similarity = -1;
-    }
-
-    /**
-     * Get the most similar position in terms of frame number
-     *
-     * @return most similar frame position
-     */
-    public int getMostSimilarFramePosition() {
-        return mostSimilarFramePosition;
-    }
-
-    /**
-     * Set the most similar position in terms of frame number
-     *
-     * @param mostSimilarFramePosition
-     */
-    public void setMostSimilarFramePosition(int mostSimilarFramePosition) {
-        this.mostSimilarFramePosition = mostSimilarFramePosition;
-    }
-
-    /**
-     * Get the similarity of the fingerprints.
-     * Similarity ranges from 0~1, where 0 means no similar feature is found and 1 means that
-     * on average there is at least one match in every frame
-     *
-     * @return fingerprints similarity
-     */
-    public float getSimilarity() {
-        return similarity;
-    }
-
-    /**
-     * Set the similarity of the fingerprints
-     *
-     * @param similarity similarity
-     */
-    public void setSimilarity(float similarity) {
-        this.similarity = similarity;
-    }
-
-    /**
-     * Get the similarity score of the fingerprints:
-     * the number of features found in the fingerprints per frame
-     *
-     * @return fingerprints similarity score
-     */
-    public float getScore() {
-        return score;
-    }
-
-    /**
-     * Set the similarity score of the fingerprints
-     *
-     * @param score
-     */
-    public void setScore(float score) {
-        this.score = score;
-    }
-
-    /**
-     * Get the most similar position in terms of time in seconds
-     *
-     * @return most similar starting time
-     */
-    public float getsetMostSimilarTimePosition() {
-        return (float) mostSimilarFramePosition / fingerprintProperties.getNumRobustPointsPerFrame()
-                        / fingerprintProperties.getFps();
-    }
-}
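A sketch of comparing two recordings end to end via Wave.getFingerprintSimilarity(), which wires the computer below to the fingerprints of both waves (paths illustrative):

    Wave a = new Wave("/tmp/a.wav");              // hypothetical paths
    Wave b = new Wave("/tmp/b.wav");
    FingerprintSimilarity result = a.getFingerprintSimilarity(b);
    System.out.println("similarity: " + result.getSimilarity()
                    + ", score: " + result.getScore());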
-package org.datavec.audio.fingerprint; - -import java.util.HashMap; -import java.util.List; - -public class FingerprintSimilarityComputer { - - private FingerprintSimilarity fingerprintSimilarity; - byte[] fingerprint1, fingerprint2; - - /** - * Constructor, ready to compute the similarity of two fingerprints - * - * @param fingerprint1 - * @param fingerprint2 - */ - public FingerprintSimilarityComputer(byte[] fingerprint1, byte[] fingerprint2) { - - this.fingerprint1 = fingerprint1; - this.fingerprint2 = fingerprint2; - - fingerprintSimilarity = new FingerprintSimilarity(); - } - - /** - * Get fingerprint similarity of input fingerprints - * - * @return fingerprint similarity object - */ - public FingerprintSimilarity getFingerprintsSimilarity() { - HashMap<Integer, Integer> offset_Score_Table = new HashMap<>(); // offset_Score_Table - int numFrames; - float score = 0; - int mostSimilarFramePosition = Integer.MIN_VALUE; - - // one frame may contain several points, use the shorter one as the denominator - if (fingerprint1.length > fingerprint2.length) { - numFrames = FingerprintManager.getNumFrames(fingerprint2); - } else { - numFrames = FingerprintManager.getNumFrames(fingerprint1); - } - - // get the pairs - PairManager pairManager = new PairManager(); - HashMap<Integer, List<Integer>> this_Pair_PositionList_Table = - pairManager.getPair_PositionList_Table(fingerprint1); - HashMap<Integer, List<Integer>> compareWave_Pair_PositionList_Table = - pairManager.getPair_PositionList_Table(fingerprint2); - - for (Integer compareWaveHashNumber : compareWave_Pair_PositionList_Table.keySet()) { - // if the compareWaveHashNumber doesn't exist in both tables, no need to compare - if (!this_Pair_PositionList_Table.containsKey(compareWaveHashNumber) - || !compareWave_Pair_PositionList_Table.containsKey(compareWaveHashNumber)) { - continue; - } - - // for each compare hash number, get the positions - List<Integer> wavePositionList = this_Pair_PositionList_Table.get(compareWaveHashNumber); - List<Integer> compareWavePositionList = compareWave_Pair_PositionList_Table.get(compareWaveHashNumber); - - for (Integer thisPosition : wavePositionList) { - for (Integer compareWavePosition : compareWavePositionList) { - int offset = thisPosition - compareWavePosition; - if (offset_Score_Table.containsKey(offset)) { - offset_Score_Table.put(offset, offset_Score_Table.get(offset) + 1); - } else { - offset_Score_Table.put(offset, 1); - } - } - } - } - - // map rank - MapRank mapRank = new MapRankInteger(offset_Score_Table, false); - - // get the most similar positions and scores - List<Integer> orderedKeyList = mapRank.getOrderedKeyList(100, true); - if (orderedKeyList.size() > 0) { - int key = orderedKeyList.get(0); - // get the highest score position - mostSimilarFramePosition = key; - score = offset_Score_Table.get(key); - - // accumulate the scores from neighbours - if (offset_Score_Table.containsKey(key - 1)) { - score += offset_Score_Table.get(key - 1) / 2; - } - if (offset_Score_Table.containsKey(key + 1)) { - score += offset_Score_Table.get(key + 1) / 2; - } - } - - /* - Iterator orderedKeyListIterator=orderedKeyList.iterator(); - while (orderedKeyListIterator.hasNext()){ - int offset=orderedKeyListIterator.next(); - System.out.println(offset+": "+offset_Score_Table.get(offset)); - } - */ - - score /= numFrames; - float similarity = score; - // similarity > 1 means on average there is at least one match in every frame - if (similarity > 1) { - similarity = 1; - } - 
- fingerprintSimilarity.setMostSimilarFramePosition(mostSimilarFramePosition); - fingerprintSimilarity.setScore(score); - fingerprintSimilarity.setSimilarity(similarity); - - return fingerprintSimilarity; - } -} diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRank.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRank.java deleted file mode 100644 index 9fa7142ae..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRank.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio.fingerprint; - -import java.util.List; - -public interface MapRank { - public List getOrderedKeyList(int numKeys, boolean sharpLimit); -} diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRankDouble.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRankDouble.java deleted file mode 100644 index f9cdb9107..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRankDouble.java +++ /dev/null @@ -1,179 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
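The heart of FingerprintSimilarityComputer is classic landmark matching: every pair hash present in both fingerprints votes for the positional offset between its occurrences, the best-supported offset becomes the alignment, and the vote count (plus half of each neighbouring offset's votes) is normalized by the shorter fingerprint's frame count. A self-contained sketch of just the voting core (names hypothetical):

    import java.util.*;

    public class OffsetVotingDemo {
        // Histogram the offsets between positions of hashes common to both tables.
        static Map.Entry<Integer, Integer> bestOffset(Map<Integer, List<Integer>> a,
                                                      Map<Integer, List<Integer>> b) {
            Map<Integer, Integer> votes = new HashMap<>();
            for (Map.Entry<Integer, List<Integer>> e : b.entrySet()) {
                List<Integer> posA = a.get(e.getKey());
                if (posA == null) continue;        // hash must occur in both
                for (int pa : posA)
                    for (int pb : e.getValue())
                        votes.merge(pa - pb, 1, Integer::sum);
            }
            // Highest vote count wins, as MapRankInteger picks it in the original.
            return votes.entrySet().stream()
                    .max(Map.Entry.comparingByValue()).orElse(null);
        }

        public static void main(String[] args) {
            Map<Integer, List<Integer>> a = Map.of(7, List.of(10, 20), 9, List.of(15));
            Map<Integer, List<Integer>> b = Map.of(7, List.of(5, 15), 9, List.of(10));
            System.out.println(bestOffset(a, b)); // 5=3 : offset 5 gets three votes
        }
    }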
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio.fingerprint; - -import java.util.*; -import java.util.Map.Entry; - -public class MapRankDouble implements MapRank { - - private Map map; - private boolean acsending = true; - - public MapRankDouble(Map map, boolean acsending) { - this.map = map; - this.acsending = acsending; - } - - public List getOrderedKeyList(int numKeys, boolean sharpLimit) { // if sharp limited, will return sharp numKeys, otherwise will return until the values not equals the exact key's value - - Set mapEntrySet = map.entrySet(); - List keyList = new LinkedList(); - - // if the numKeys is larger than map size, limit it - if (numKeys > map.size()) { - numKeys = map.size(); - } - // end if the numKeys is larger than map size, limit it - - if (map.size() > 0) { - double[] array = new double[map.size()]; - int count = 0; - - // get the pass values - Iterator mapIterator = mapEntrySet.iterator(); - while (mapIterator.hasNext()) { - Entry entry = mapIterator.next(); - array[count++] = (Double) entry.getValue(); - } - // end get the pass values - - int targetindex; - if (acsending) { - targetindex = numKeys; - } else { - targetindex = array.length - numKeys; - } - - double passValue = getOrderedValue(array, targetindex); // this value is the value of the numKey-th element - // get the passed keys and values - Map passedMap = new HashMap(); - List valueList = new LinkedList(); - mapIterator = mapEntrySet.iterator(); - - while (mapIterator.hasNext()) { - Entry entry = mapIterator.next(); - double value = (Double) entry.getValue(); - if ((acsending && value <= passValue) || (!acsending && value >= passValue)) { - passedMap.put(entry.getKey(), value); - valueList.add(value); - } - } - // end get the passed keys and values - - // sort the value list - Double[] listArr = new Double[valueList.size()]; - valueList.toArray(listArr); - Arrays.sort(listArr); - // end sort the value list - - // get the list of keys - int resultCount = 0; - int index; - if (acsending) { - index = 0; - } else { - index = listArr.length - 1; - } - - if (!sharpLimit) { - numKeys = listArr.length; - } - - while (true) { - double targetValue = (Double) listArr[index]; - Iterator passedMapIterator = passedMap.entrySet().iterator(); - while (passedMapIterator.hasNext()) { - Entry entry = passedMapIterator.next(); - if ((Double) entry.getValue() == targetValue) { - keyList.add(entry.getKey()); - passedMapIterator.remove(); - resultCount++; - break; - } - } - - if (acsending) { - index++; - } else { - index--; - } - - if (resultCount >= numKeys) { - break; - } - } - // end get the list of keys - } - - return keyList; - } - - private double getOrderedValue(double[] array, int index) { - locate(array, 0, array.length - 1, index); - return array[index]; - } - - // sort the partitions by quick sort, and locate the target index - private void locate(double[] array, int left, int right, int index) { - - int mid = (left + right) / 2; - //System.out.println(left+" to "+right+" ("+mid+")"); - - if (right == left) { - //System.out.println("* "+array[targetIndex]); - //result=array[targetIndex]; - return; - } - - if (left < right) { - double s = array[mid]; - int i = left - 1; - int j = right + 1; - - while (true) { - while (array[++i] < s); - while (array[--j] > s); - if (i >= j) - break; - swap(array, i, j); - } - - //System.out.println("2 parts: "+left+"-"+(i-1)+" and "+(j+1)+"-"+right); - - if (i > index) { - 
// the target index in the left partition - //System.out.println("left partition"); - locate(array, left, i - 1, index); - } else { - // the target index in the right partition - //System.out.println("right partition"); - locate(array, j + 1, right, index); - } - } - } - - private void swap(double[] array, int i, int j) { - double t = array[i]; - array[i] = array[j]; - array[j] = t; - } -} diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRankInteger.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRankInteger.java deleted file mode 100644 index ed79ffd24..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/MapRankInteger.java +++ /dev/null @@ -1,179 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio.fingerprint; - -import java.util.*; -import java.util.Map.Entry; - -public class MapRankInteger implements MapRank { - - private Map map; - private boolean acsending = true; - - public MapRankInteger(Map map, boolean acsending) { - this.map = map; - this.acsending = acsending; - } - - public List getOrderedKeyList(int numKeys, boolean sharpLimit) { // if sharp limited, will return sharp numKeys, otherwise will return until the values not equals the exact key's value - - Set mapEntrySet = map.entrySet(); - List keyList = new LinkedList(); - - // if the numKeys is larger than map size, limit it - if (numKeys > map.size()) { - numKeys = map.size(); - } - // end if the numKeys is larger than map size, limit it - - if (map.size() > 0) { - int[] array = new int[map.size()]; - int count = 0; - - // get the pass values - Iterator mapIterator = mapEntrySet.iterator(); - while (mapIterator.hasNext()) { - Entry entry = mapIterator.next(); - array[count++] = (Integer) entry.getValue(); - } - // end get the pass values - - int targetindex; - if (acsending) { - targetindex = numKeys; - } else { - targetindex = array.length - numKeys; - } - - int passValue = getOrderedValue(array, targetindex); // this value is the value of the numKey-th element - // get the passed keys and values - Map passedMap = new HashMap(); - List valueList = new LinkedList(); - mapIterator = mapEntrySet.iterator(); - - while (mapIterator.hasNext()) { - Entry entry = mapIterator.next(); - int value = (Integer) entry.getValue(); - if ((acsending && value <= passValue) || (!acsending && value >= passValue)) { - passedMap.put(entry.getKey(), value); - valueList.add(value); - } - } - // end get the passed keys and values - - // sort the value list - Integer[] listArr = new Integer[valueList.size()]; - valueList.toArray(listArr); - 
Arrays.sort(listArr); - // end sort the value list - - // get the list of keys - int resultCount = 0; - int index; - if (acsending) { - index = 0; - } else { - index = listArr.length - 1; - } - - if (!sharpLimit) { - numKeys = listArr.length; - } - - while (true) { - int targetValue = (Integer) listArr[index]; - Iterator passedMapIterator = passedMap.entrySet().iterator(); - while (passedMapIterator.hasNext()) { - Entry entry = passedMapIterator.next(); - if ((Integer) entry.getValue() == targetValue) { - keyList.add(entry.getKey()); - passedMapIterator.remove(); - resultCount++; - break; - } - } - - if (acsending) { - index++; - } else { - index--; - } - - if (resultCount >= numKeys) { - break; - } - } - // end get the list of keys - } - - return keyList; - } - - private int getOrderedValue(int[] array, int index) { - locate(array, 0, array.length - 1, index); - return array[index]; - } - - // sort the partitions by quick sort, and locate the target index - private void locate(int[] array, int left, int right, int index) { - - int mid = (left + right) / 2; - //System.out.println(left+" to "+right+" ("+mid+")"); - - if (right == left) { - //System.out.println("* "+array[targetIndex]); - //result=array[targetIndex]; - return; - } - - if (left < right) { - int s = array[mid]; - int i = left - 1; - int j = right + 1; - - while (true) { - while (array[++i] < s); - while (array[--j] > s); - if (i >= j) - break; - swap(array, i, j); - } - - //System.out.println("2 parts: "+left+"-"+(i-1)+" and "+(j+1)+"-"+right); - - if (i > index) { - // the target index in the left partition - //System.out.println("left partition"); - locate(array, left, i - 1, index); - } else { - // the target index in the right partition - //System.out.println("right partition"); - locate(array, j + 1, right, index); - } - } - } - - private void swap(int[] array, int i, int j) { - int t = array[i]; - array[i] = array[j]; - array[j] = t; - } -} diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/PairManager.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/PairManager.java deleted file mode 100644 index 71c608c65..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/PairManager.java +++ /dev/null @@ -1,230 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
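Both MapRank implementations avoid a full sort of the histogram: locate() is a Hoare-partition quickselect that stops once the element at the target index is in its sorted place, so the numKeys-th value can serve as a pass threshold in expected linear time. A compact, self-contained version of that idea (hypothetical names; note it reorders the array, just as locate() does):

    public class QuickselectDemo {
        // Partially partitions until a[index] holds its sorted-order value,
        // the same scheme MapRankInteger/MapRankDouble use in locate().
        static int nth(int[] a, int index) {
            int left = 0, right = a.length - 1;
            while (left < right) {
                int pivot = a[(left + right) / 2];
                int i = left - 1, j = right + 1;
                while (true) {
                    do { i++; } while (a[i] < pivot);
                    do { j--; } while (a[j] > pivot);
                    if (i >= j) break;
                    int t = a[i]; a[i] = a[j]; a[j] = t;
                }
                if (index <= j) right = j; else left = j + 1; // keep only one side
            }
            return a[index];
        }

        public static void main(String[] args) {
            int[] values = {9, 1, 8, 2, 7, 3};
            System.out.println(nth(values, 2)); // third smallest -> 3
        }
    }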
-package org.datavec.audio.fingerprint; - - - -import org.datavec.audio.properties.FingerprintProperties; - -import java.util.HashMap; -import java.util.LinkedList; -import java.util.List; - -public class PairManager { - - FingerprintProperties fingerprintProperties = FingerprintProperties.getInstance(); - private int numFilterBanks = fingerprintProperties.getNumFilterBanks(); - private int bandwidthPerBank = fingerprintProperties.getNumFrequencyUnits() / numFilterBanks; - private int anchorPointsIntervalLength = fingerprintProperties.getAnchorPointsIntervalLength(); - private int numAnchorPointsPerInterval = fingerprintProperties.getNumAnchorPointsPerInterval(); - private int maxTargetZoneDistance = fingerprintProperties.getMaxTargetZoneDistance(); - private int numFrequencyUnits = fingerprintProperties.getNumFrequencyUnits(); - - private int maxPairs; - private boolean isReferencePairing; - private HashMap stopPairTable = new HashMap<>(); - - /** - * Constructor - */ - public PairManager() { - maxPairs = fingerprintProperties.getRefMaxActivePairs(); - isReferencePairing = true; - } - - /** - * Constructor; the number of pairs of robust points depends on the parameter isReferencePairing - * The number of pairs for reference and sample can differ due to environmental influence on the source - * @param isReferencePairing - */ - public PairManager(boolean isReferencePairing) { - if (isReferencePairing) { - maxPairs = fingerprintProperties.getRefMaxActivePairs(); - } else { - maxPairs = fingerprintProperties.getSampleMaxActivePairs(); - } - this.isReferencePairing = isReferencePairing; - } - - /** - * Get a pair-positionList table - * It's a hash map in which the key is the hashed pair and the value is a list of positions - * i.e. the table stores all positions that share the same hashed pair - * - * @param fingerprint fingerprint bytes - * @return pair-positionList HashMap - */ - public HashMap<Integer, List<Integer>> getPair_PositionList_Table(byte[] fingerprint) { - - List<int[]> pairPositionList = getPairPositionList(fingerprint); - - // table to store pair:pos,pos,pos,...;pair2:pos,pos,pos,....
- HashMap<Integer, List<Integer>> pair_positionList_table = new HashMap<>(); - - // get all pair_positions from list, use a table to collect the data group by pair hashcode - for (int[] pair_position : pairPositionList) { - //System.out.println(pair_position[0]+","+pair_position[1]); - - // group by pair-hashcode, i.e.: <pairHashcode, List<position>> - if (pair_positionList_table.containsKey(pair_position[0])) { - pair_positionList_table.get(pair_position[0]).add(pair_position[1]); - } else { - List<Integer> positionList = new LinkedList<>(); - positionList.add(pair_position[1]); - pair_positionList_table.put(pair_position[0], positionList); - } - // end group by pair-hashcode, i.e.: <pairHashcode, List<position>> - } - // end get all pair_positions from list, use a table to collect the data group by pair hashcode - - return pair_positionList_table; - } - - // this return list contains: int[0]=pair_hashcode, int[1]=position - private List<int[]> getPairPositionList(byte[] fingerprint) { - - int numFrames = FingerprintManager.getNumFrames(fingerprint); - - // table for paired frames - byte[] pairedFrameTable = new byte[numFrames / anchorPointsIntervalLength + 1]; // each interval has numAnchorPointsPerInterval pairs only - // end table for paired frames - - List<int[]> pairList = new LinkedList<>(); - List<int[]> sortedCoordinateList = getSortedCoordinateList(fingerprint); - - for (int[] anchorPoint : sortedCoordinateList) { - int anchorX = anchorPoint[0]; - int anchorY = anchorPoint[1]; - int numPairs = 0; - - for (int[] aSortedCoordinateList : sortedCoordinateList) { - - if (numPairs >= maxPairs) { - break; - } - - if (isReferencePairing && pairedFrameTable[anchorX - / anchorPointsIntervalLength] >= numAnchorPointsPerInterval) { - break; - } - - int targetX = aSortedCoordinateList[0]; - int targetY = aSortedCoordinateList[1]; - - if (anchorX == targetX && anchorY == targetY) { - continue; - } - - // pair up the points - int x1, y1, x2, y2; // x2 always >= x1 - if (targetX >= anchorX) { - x2 = targetX; - y2 = targetY; - x1 = anchorX; - y1 = anchorY; - } else { - x2 = anchorX; - y2 = anchorY; - x1 = targetX; - y1 = targetY; - } - - // check target zone - if ((x2 - x1) > maxTargetZoneDistance) { - continue; - } - // end check target zone - - // check filter bank zone - if (!(y1 / bandwidthPerBank == y2 / bandwidthPerBank)) { - continue; // both points must fall in the same filter bank - } - // end check filter bank zone - - int pairHashcode = (x2 - x1) * numFrequencyUnits * numFrequencyUnits + y2 * numFrequencyUnits + y1; - - // stop list applied on sample pairing only - if (!isReferencePairing && stopPairTable.containsKey(pairHashcode)) { - numPairs++; // no reservation - continue; // escape this point only - } - // end stop list applied on sample pairing only - - // pass all rules - pairList.add(new int[] {pairHashcode, anchorX}); - pairedFrameTable[anchorX / anchorPointsIntervalLength]++; - numPairs++; - // end pair up the points - } - } - - return pairList; - } - - private List<int[]> getSortedCoordinateList(byte[] fingerprint) { - // each point data is 8 bytes - // first 2 bytes is x - // next 2 bytes is y - // next 4 bytes is intensity - - // get all intensities - int numCoordinates = fingerprint.length / 8; - int[] intensities = new int[numCoordinates]; - for (int i = 0; i < numCoordinates; i++) { - int pointer = i * 8 + 4; - int intensity = (fingerprint[pointer] & 0xff) << 24 | (fingerprint[pointer + 1] & 0xff) << 16 - | (fingerprint[pointer + 2] & 0xff) << 8 | (fingerprint[pointer + 3] & 0xff); - intensities[i] = intensity; - } - - QuickSortIndexPreserved quicksort = new QuickSortIndexPreserved(intensities); - int[] sortIndexes = quicksort.getSortIndexes(); - - List<int[]> sortedCoordinateList = new LinkedList<>(); - for (int i = sortIndexes.length - 1; i >= 0; i--) { - int pointer = sortIndexes[i] * 8; - int x = (fingerprint[pointer] & 0xff) << 8 | (fingerprint[pointer + 1] & 0xff); - int y = (fingerprint[pointer + 2] & 0xff) << 8 | (fingerprint[pointer + 3] & 0xff); - sortedCoordinateList.add(new int[] {x, y}); - } - return sortedCoordinateList; - } - - /** - * Convert hashed pair to bytes - * - * @param pairHashcode hashed pair - * @return byte array - */ - public static byte[] pairHashcodeToBytes(int pairHashcode) { - return new byte[] {(byte) (pairHashcode >> 8), (byte) pairHashcode}; - } - - /** - * Convert bytes to hashed pair - * - * @param pairBytes - * @return hashed pair - */ - public static int pairBytesToHashcode(byte[] pairBytes) { - return (pairBytes[0] & 0xFF) << 8 | (pairBytes[1] & 0xFF); - } -} diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSort.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSort.java deleted file mode 100644 index 5ff0fe31b..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSort.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio.fingerprint; - -public abstract class QuickSort { - public abstract int[] getSortIndexes(); -} diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortDouble.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortDouble.java deleted file mode 100644 index bf8939298..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortDouble.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
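Two details of PairManager are easy to overlook: a pair is packed as (x2 - x1) * F^2 + y2 * F + y1 with F = numFrequencyUnits, and pairHashcodeToBytes() then keeps only the low 16 bits, so with the default F = 221 hashes above 0xFFFF collide after byte packing. A sketch of the round trip (values illustrative; the collision remark is my reading of the byte casts, not documented behaviour):

    public class PairHashDemo {
        static final int F = 221; // numFrequencyUnits under the default properties

        // Same packing as PairManager: time gap plus the two frequency bins.
        static int pairHashcode(int dx, int y1, int y2) {
            return dx * F * F + y2 * F + y1;
        }

        // Mirrors pairHashcodeToBytes / pairBytesToHashcode: only 16 bits survive.
        static byte[] toBytes(int h) { return new byte[] {(byte) (h >> 8), (byte) h}; }
        static int fromBytes(byte[] b) { return (b[0] & 0xFF) << 8 | (b[1] & 0xFF); }

        public static void main(String[] args) {
            int h = pairHashcode(2, 17, 40); // 2*221*221 + 40*221 + 17 = 106539
            System.out.println(h + " -> " + fromBytes(toBytes(h))); // 106539 -> 41003
        }
    }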
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio.fingerprint; - -public class QuickSortDouble extends QuickSort { - - private int[] indexes; - private double[] array; - - public QuickSortDouble(double[] array) { - this.array = array; - indexes = new int[array.length]; - for (int i = 0; i < indexes.length; i++) { - indexes[i] = i; - } - } - - public int[] getSortIndexes() { - sort(); - return indexes; - } - - private void sort() { - quicksort(array, indexes, 0, indexes.length - 1); - } - - // quicksort a[left] to a[right] - private void quicksort(double[] a, int[] indexes, int left, int right) { - if (right <= left) - return; - int i = partition(a, indexes, left, right); - quicksort(a, indexes, left, i - 1); - quicksort(a, indexes, i + 1, right); - } - - // partition a[left] to a[right], assumes left < right - private int partition(double[] a, int[] indexes, int left, int right) { - int i = left - 1; - int j = right; - while (true) { - while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel - while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap - if (j == left) - break; // don't go out-of-bounds - } - if (i >= j) - break; // check if pointers cross - swap(a, indexes, i, j); // swap two elements into place - } - swap(a, indexes, i, right); // swap with partition element - return i; - } - - // exchange a[i] and a[j] - private void swap(double[] a, int[] indexes, int i, int j) { - int swap = indexes[i]; - indexes[i] = indexes[j]; - indexes[j] = swap; - } - -} diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortIndexPreserved.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortIndexPreserved.java deleted file mode 100644 index 9a21b74e5..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortIndexPreserved.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio.fingerprint; - -public class QuickSortIndexPreserved { - - private QuickSort quickSort; - - public QuickSortIndexPreserved(int[] array) { - quickSort = new QuickSortInteger(array); - } - - public QuickSortIndexPreserved(double[] array) { - quickSort = new QuickSortDouble(array); - } - - public QuickSortIndexPreserved(short[] array) { - quickSort = new QuickSortShort(array); - } - - public int[] getSortIndexes() { - return quickSort.getSortIndexes(); - } - -} diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortInteger.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortInteger.java deleted file mode 100644 index a89318910..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortInteger.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
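The QuickSort* family is an argsort: it sorts an index array by the values it points at and leaves the values untouched, which is what lets getSortedCoordinateList() walk fingerprint points in descending intensity while reading coordinates from the original byte array. A usage sketch, assuming the deleted fingerprint classes are still on the classpath:

    import org.datavec.audio.fingerprint.QuickSortIndexPreserved;

    public class ArgsortDemo {
        public static void main(String[] args) {
            int[] intensities = {40, 10, 30, 20};
            // Dispatches to QuickSortInteger for an int[] input.
            QuickSortIndexPreserved sorter = new QuickSortIndexPreserved(intensities);
            int[] idx = sorter.getSortIndexes(); // ascending by value: [1, 3, 2, 0]
            // Walk the indexes backwards for descending intensity, as PairManager does.
            for (int i = idx.length - 1; i >= 0; i--)
                System.out.print(intensities[idx[i]] + " "); // 40 30 20 10
        }
    }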
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio.fingerprint; - -public class QuickSortInteger extends QuickSort { - - private int[] indexes; - private int[] array; - - public QuickSortInteger(int[] array) { - this.array = array; - indexes = new int[array.length]; - for (int i = 0; i < indexes.length; i++) { - indexes[i] = i; - } - } - - public int[] getSortIndexes() { - sort(); - return indexes; - } - - private void sort() { - quicksort(array, indexes, 0, indexes.length - 1); - } - - // quicksort a[left] to a[right] - private void quicksort(int[] a, int[] indexes, int left, int right) { - if (right <= left) - return; - int i = partition(a, indexes, left, right); - quicksort(a, indexes, left, i - 1); - quicksort(a, indexes, i + 1, right); - } - - // partition a[left] to a[right], assumes left < right - private int partition(int[] a, int[] indexes, int left, int right) { - int i = left - 1; - int j = right; - while (true) { - while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel - while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap - if (j == left) - break; // don't go out-of-bounds - } - if (i >= j) - break; // check if pointers cross - swap(a, indexes, i, j); // swap two elements into place - } - swap(a, indexes, i, right); // swap with partition element - return i; - } - - // exchange a[i] and a[j] - private void swap(int[] a, int[] indexes, int i, int j) { - int swap = indexes[i]; - indexes[i] = indexes[j]; - indexes[j] = swap; - } - -} diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortShort.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortShort.java deleted file mode 100644 index 5230740d9..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/fingerprint/QuickSortShort.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio.fingerprint; - -public class QuickSortShort extends QuickSort { - - private int[] indexes; - private short[] array; - - public QuickSortShort(short[] array) { - this.array = array; - indexes = new int[array.length]; - for (int i = 0; i < indexes.length; i++) { - indexes[i] = i; - } - } - - public int[] getSortIndexes() { - sort(); - return indexes; - } - - private void sort() { - quicksort(array, indexes, 0, indexes.length - 1); - } - - // quicksort a[left] to a[right] - private void quicksort(short[] a, int[] indexes, int left, int right) { - if (right <= left) - return; - int i = partition(a, indexes, left, right); - quicksort(a, indexes, left, i - 1); - quicksort(a, indexes, i + 1, right); - } - - // partition a[left] to a[right], assumes left < right - private int partition(short[] a, int[] indexes, int left, int right) { - int i = left - 1; - int j = right; - while (true) { - while (a[indexes[++i]] < a[indexes[right]]); // find item on left to swap, a[right] acts as sentinel - while (a[indexes[right]] < a[indexes[--j]]) { // find item on right to swap - if (j == left) - break; // don't go out-of-bounds - } - if (i >= j) - break; // check if pointers cross - swap(a, indexes, i, j); // swap two elements into place - } - swap(a, indexes, i, right); // swap with partition element - return i; - } - - // exchange a[i] and a[j] - private void swap(short[] a, int[] indexes, int i, int j) { - int swap = indexes[i]; - indexes[i] = indexes[j]; - indexes[j] = swap; - } - -} diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/formats/input/WavInputFormat.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/formats/input/WavInputFormat.java deleted file mode 100644 index d78639972..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/formats/input/WavInputFormat.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio.formats.input; - -import org.datavec.api.conf.Configuration; -import org.datavec.api.formats.input.BaseInputFormat; -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.split.InputSplit; -import org.datavec.audio.recordreader.WavFileRecordReader; - -import java.io.IOException; - -/** - * - * Wave file input format - * - * @author Adam Gibson - */ -public class WavInputFormat extends BaseInputFormat { - @Override - public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException { - return createReader(split); - } - - @Override - public RecordReader createReader(InputSplit split) throws IOException, InterruptedException { - RecordReader waveRecordReader = new WavFileRecordReader(); - waveRecordReader.initialize(split); - return waveRecordReader; - } - - -} diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/formats/output/WaveOutputFormat.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/formats/output/WaveOutputFormat.java deleted file mode 100644 index 986508268..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/formats/output/WaveOutputFormat.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio.formats.output; - -import org.datavec.api.conf.Configuration; -import org.datavec.api.exceptions.DataVecException; -import org.datavec.api.formats.output.OutputFormat; -import org.datavec.api.records.writer.RecordWriter; - -/** - * @author Adam Gibson - */ -public class WaveOutputFormat implements OutputFormat { - @Override - public RecordWriter createWriter(Configuration conf) throws DataVecException { - return null; - } -} diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/processor/ArrayRankDouble.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/processor/ArrayRankDouble.java deleted file mode 100644 index eb8b537ef..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/processor/ArrayRankDouble.java +++ /dev/null @@ -1,139 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. 
- * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio.processor; - -public class ArrayRankDouble { - - /** - * Get the index position of the maximum value in the given array - * @param array an array - * @return index of the max value in array - */ - public int getMaxValueIndex(double[] array) { - - int index = 0; - double max = Double.NEGATIVE_INFINITY; - - for (int i = 0; i < array.length; i++) { - if (array[i] > max) { - max = array[i]; - index = i; - } - } - - return index; - } - - /** - * Get the index position of the minimum value in the given array - * @param array an array - * @return index of the min value in array - */ - public int getMinValueIndex(double[] array) { - - int index = 0; - double min = Double.POSITIVE_INFINITY; - - for (int i = 0; i < array.length; i++) { - if (array[i] < min) { - min = array[i]; - index = i; - } - } - - return index; - } - - /** - * Get the n-th value of the array in sorted order - * @param array an array - * @param n position in array - * @param ascending whether to order ascending - * @return value at nth position of array - */ - public double getNthOrderedValue(double[] array, int n, boolean ascending) { - - if (n > array.length) { - n = array.length; - } - - int targetindex; - if (ascending) { - targetindex = n; - } else { - targetindex = array.length - n; - } - - // this value is the value of the numKey-th element - - return getOrderedValue(array, targetindex); - } - - private double getOrderedValue(double[] array, int index) { - locate(array, 0, array.length - 1, index); - return array[index]; - } - - // sort the partitions by quick sort, and locate the target index - private void locate(double[] array, int left, int right, int index) { - - int mid = (left + right) / 2; - // System.out.println(left+" to "+right+" ("+mid+")"); - - if (right == left) { - // System.out.println("* "+array[targetIndex]); - // result=array[targetIndex]; - return; - } - - if (left < right) { - double s = array[mid]; - int i = left - 1; - int j = right + 1; - - while (true) { - while (array[++i] < s); - while (array[--j] > s); - if (i >= j) - break; - swap(array, i, j); - } - - // System.out.println("2 parts: "+left+"-"+(i-1)+" and "+(j+1)+"-"+right); - - if (i > index) { - // the target index in the left partition - // System.out.println("left partition"); - locate(array, left, i - 1, index); - } else { - // the target index in the right partition - // System.out.println("right partition"); - locate(array, j + 1, right, index); - } - } - } - - private void swap(double[] array, int i, int j) { - double t = array[i]; - array[i] = array[j]; - array[j] = t; - } -} diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/processor/IntensityProcessor.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/processor/IntensityProcessor.java deleted file mode 100644 index 3f49e6a58..000000000 ---
a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/processor/IntensityProcessor.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio.processor; - -public interface IntensityProcessor { - - public void execute(); - - public double[][] getIntensities(); -} diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/processor/ProcessorChain.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/processor/ProcessorChain.java deleted file mode 100644 index 8a0d6a5e1..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/processor/ProcessorChain.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
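ArrayRankDouble's getNthOrderedValue(array, n, false) above returns the n-th largest value via the same quickselect, which RobustIntensityProcessor further below uses as a per-frame threshold. A quick usage sketch (assuming the class is on the classpath; the clone() is mine, since locate() reorders the array in place):

    import org.datavec.audio.processor.ArrayRankDouble;

    public class NthValueDemo {
        public static void main(String[] args) {
            double[] frame = {0.2, 0.9, 0.5, 0.7};
            ArrayRankDouble rank = new ArrayRankDouble();
            // Second-largest value of the frame.
            System.out.println(rank.getNthOrderedValue(frame.clone(), 2, false)); // 0.7
        }
    }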
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio.processor; - -import java.util.LinkedList; -import java.util.List; - -public class ProcessorChain { - - private double[][] intensities; - List processorList = new LinkedList(); - - public ProcessorChain(double[][] intensities) { - this.intensities = intensities; - RobustIntensityProcessor robustProcessor = new RobustIntensityProcessor(intensities, 1); - processorList.add(robustProcessor); - process(); - } - - private void process() { - for (IntensityProcessor processor : processorList) { - processor.execute(); - intensities = processor.getIntensities(); - } - } - - public double[][] getIntensities() { - return intensities; - } -} diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/processor/RobustIntensityProcessor.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/processor/RobustIntensityProcessor.java deleted file mode 100644 index 91ba4806c..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/processor/RobustIntensityProcessor.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio.processor; - - -public class RobustIntensityProcessor implements IntensityProcessor { - - private double[][] intensities; - private int numPointsPerFrame; - - public RobustIntensityProcessor(double[][] intensities, int numPointsPerFrame) { - this.intensities = intensities; - this.numPointsPerFrame = numPointsPerFrame; - } - - public void execute() { - - int numX = intensities.length; - int numY = intensities[0].length; - double[][] processedIntensities = new double[numX][numY]; - - for (int i = 0; i < numX; i++) { - double[] tmpArray = new double[numY]; - System.arraycopy(intensities[i], 0, tmpArray, 0, numY); - - // pass value is the last some elements in sorted array - ArrayRankDouble arrayRankDouble = new ArrayRankDouble(); - double passValue = arrayRankDouble.getNthOrderedValue(tmpArray, numPointsPerFrame, false); - - // only passed elements will be assigned a value - for (int j = 0; j < numY; j++) { - if (intensities[i][j] >= passValue) { - processedIntensities[i][j] = intensities[i][j]; - } - } - } - intensities = processedIntensities; - } - - public double[][] getIntensities() { - return intensities; - } -} diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/processor/TopManyPointsProcessorChain.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/processor/TopManyPointsProcessorChain.java deleted file mode 100644 index c190e9bd8..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/processor/TopManyPointsProcessorChain.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
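RobustIntensityProcessor keeps, for each time frame (row), only the numPointsPerFrame largest intensities and zeroes the rest; ties at the threshold can keep a few extra points, exactly as in the original >= comparison. The same filtering in isolation, with a plain sort standing in for ArrayRankDouble (names hypothetical):

    import java.util.Arrays;

    public class TopNPerFrameDemo {
        // Keep the top-n values per row, zero everything else.
        static double[][] topNPerFrame(double[][] frames, int n) {
            double[][] out = new double[frames.length][];
            for (int i = 0; i < frames.length; i++) {
                double[] sorted = frames[i].clone();
                Arrays.sort(sorted);
                double pass = sorted[sorted.length - n]; // n-th largest as threshold
                out[i] = new double[frames[i].length];
                for (int j = 0; j < frames[i].length; j++)
                    if (frames[i][j] >= pass) out[i][j] = frames[i][j];
            }
            return out;
        }

        public static void main(String[] args) {
            double[][] frames = {{0.1, 0.9, 0.4, 0.6}};
            System.out.println(Arrays.deepToString(topNPerFrame(frames, 2)));
            // [[0.0, 0.9, 0.0, 0.6]]
        }
    }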
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio.processor; - -import java.util.LinkedList; -import java.util.List; - - -public class TopManyPointsProcessorChain { - - private double[][] intensities; - List processorList = new LinkedList<>(); - - public TopManyPointsProcessorChain(double[][] intensities, int numPoints) { - this.intensities = intensities; - RobustIntensityProcessor robustProcessor = new RobustIntensityProcessor(intensities, numPoints); - processorList.add(robustProcessor); - process(); - } - - private void process() { - for (IntensityProcessor processor : processorList) { - processor.execute(); - intensities = processor.getIntensities(); - } - } - - public double[][] getIntensities() { - return intensities; - } -} diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/properties/FingerprintProperties.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/properties/FingerprintProperties.java deleted file mode 100644 index 0075cfa63..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/properties/FingerprintProperties.java +++ /dev/null @@ -1,121 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio.properties; - -public class FingerprintProperties { - - protected static FingerprintProperties instance = null; - - private int numRobustPointsPerFrame = 4; // number of points in each frame, i.e. top 4 intensities in fingerprint - private int sampleSizePerFrame = 2048; // number of audio samples in a frame, it is suggested to be the FFT Size - private int overlapFactor = 4; // 8 means each move 1/8 nSample length. 1 means no overlap, better 1,2,4,8 ... 32 - private int numFilterBanks = 4; - - private int upperBoundedFrequency = 1500; // low pass - private int lowerBoundedFrequency = 400; // high pass - private int fps = 5; // in order to have 5fps with 2048 sampleSizePerFrame, wave's sample rate need to be 10240 (sampleSizePerFrame*fps) - private int sampleRate = sampleSizePerFrame * fps; // the audio's sample rate needed to resample to this in order to fit the sampleSizePerFrame and fps - private int numFramesInOneSecond = overlapFactor * fps; // since the overlap factor affects the actual number of fps, so this value is used to evaluate how many frames in one second eventually - - private int refMaxActivePairs = 1; // max. active pairs per anchor point for reference songs - private int sampleMaxActivePairs = 10; // max. 
active pairs per anchor point for sample clip - private int numAnchorPointsPerInterval = 10; - private int anchorPointsIntervalLength = 4; // in frames (5fps,4 overlap per second) - private int maxTargetZoneDistance = 4; // in frame (5fps,4 overlap per second) - - private int numFrequencyUnits = (upperBoundedFrequency - lowerBoundedFrequency + 1) / fps + 1; // num frequency units - - public static FingerprintProperties getInstance() { - if (instance == null) { - synchronized (FingerprintProperties.class) { - if (instance == null) { - instance = new FingerprintProperties(); - } - } - } - return instance; - } - - public int getNumRobustPointsPerFrame() { - return numRobustPointsPerFrame; - } - - public int getSampleSizePerFrame() { - return sampleSizePerFrame; - } - - public int getOverlapFactor() { - return overlapFactor; - } - - public int getNumFilterBanks() { - return numFilterBanks; - } - - public int getUpperBoundedFrequency() { - return upperBoundedFrequency; - } - - public int getLowerBoundedFrequency() { - return lowerBoundedFrequency; - } - - public int getFps() { - return fps; - } - - public int getRefMaxActivePairs() { - return refMaxActivePairs; - } - - public int getSampleMaxActivePairs() { - return sampleMaxActivePairs; - } - - public int getNumAnchorPointsPerInterval() { - return numAnchorPointsPerInterval; - } - - public int getAnchorPointsIntervalLength() { - return anchorPointsIntervalLength; - } - - public int getMaxTargetZoneDistance() { - return maxTargetZoneDistance; - } - - public int getNumFrequencyUnits() { - return numFrequencyUnits; - } - - public int getMaxPossiblePairHashcode() { - return maxTargetZoneDistance * numFrequencyUnits * numFrequencyUnits + numFrequencyUnits * numFrequencyUnits - + numFrequencyUnits; - } - - public int getSampleRate() { - return sampleRate; - } - - public int getNumFramesInOneSecond() { - return numFramesInOneSecond; - } -} diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/recordreader/BaseAudioRecordReader.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/recordreader/BaseAudioRecordReader.java deleted file mode 100644 index 4702f9d89..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/recordreader/BaseAudioRecordReader.java +++ /dev/null @@ -1,225 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
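Most values in FingerprintProperties are derived rather than free: the wave must be resampled to sampleSizePerFrame * fps, the effective frame rate is overlapFactor * fps, and the analysed band collapses to numFrequencyUnits bins. Plugging in the defaults (a sketch; it prints the same numbers the getters would return):

    public class PropertiesDemo {
        public static void main(String[] args) {
            int sampleSizePerFrame = 2048, fps = 5, overlapFactor = 4;
            int upper = 1500, lower = 400; // band-pass bounds in Hz
            System.out.println("sampleRate        = " + sampleSizePerFrame * fps);        // 10240
            System.out.println("framesPerSecond   = " + overlapFactor * fps);             // 20
            System.out.println("numFrequencyUnits = " + ((upper - lower + 1) / fps + 1)); // 221
        }
    }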
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio.recordreader; - -import org.apache.commons.io.FileUtils; -import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; -import org.datavec.api.records.metadata.RecordMetaData; -import org.datavec.api.records.reader.BaseRecordReader; -import org.datavec.api.split.BaseInputSplit; -import org.datavec.api.split.InputSplit; -import org.datavec.api.split.InputStreamInputSplit; -import org.datavec.api.writable.DoubleWritable; -import org.datavec.api.writable.Writable; - -import java.io.DataInputStream; -import java.io.File; -import java.io.IOException; -import java.io.InputStream; -import java.net.URI; -import java.nio.file.Path; -import java.nio.file.Paths; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; - -/** - * Base audio file loader - * @author Adam Gibson - */ -public abstract class BaseAudioRecordReader extends BaseRecordReader { - private Iterator iter; - private List record; - private boolean hitImage = false; - private boolean appendLabel = false; - private List labels = new ArrayList<>(); - private Configuration conf; - protected InputSplit inputSplit; - - public BaseAudioRecordReader() {} - - public BaseAudioRecordReader(boolean appendLabel, List labels) { - this.appendLabel = appendLabel; - this.labels = labels; - } - - public BaseAudioRecordReader(List labels) { - this.labels = labels; - } - - public BaseAudioRecordReader(boolean appendLabel) { - this.appendLabel = appendLabel; - } - - protected abstract List loadData(File file, InputStream inputStream) throws IOException; - - @Override - public void initialize(InputSplit split) throws IOException, InterruptedException { - inputSplit = split; - if (split instanceof BaseInputSplit) { - URI[] locations = split.locations(); - if (locations != null && locations.length >= 1) { - if (locations.length > 1) { - List allFiles = new ArrayList<>(); - for (URI location : locations) { - File iter = new File(location); - if (iter.isDirectory()) { - Iterator allFiles2 = FileUtils.iterateFiles(iter, null, true); - while (allFiles2.hasNext()) - allFiles.add(allFiles2.next()); - } - - else - allFiles.add(iter); - } - - iter = allFiles.iterator(); - } else { - File curr = new File(locations[0]); - if (curr.isDirectory()) - iter = FileUtils.iterateFiles(curr, null, true); - else - iter = Collections.singletonList(curr).iterator(); - } - } - } - - - else if (split instanceof InputStreamInputSplit) { - record = new ArrayList<>(); - InputStreamInputSplit split2 = (InputStreamInputSplit) split; - InputStream is = split2.getIs(); - URI[] locations = split2.locations(); - if (appendLabel) { - Path path = Paths.get(locations[0]); - String parent = path.getParent().toString(); - record.add(new DoubleWritable(labels.indexOf(parent))); - } - - is.close(); - } - - } - - @Override - public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException { - this.conf = conf; - this.appendLabel = conf.getBoolean(APPEND_LABEL, false); - this.labels = new ArrayList<>(conf.getStringCollection(LABELS)); - initialize(split); - } - - @Override - public List next() { - if (iter != null) { - File next = iter.next(); - invokeListeners(next); - try { - return loadData(next, null); - } catch (Exception e) { - throw new RuntimeException(e); - } - } else if (record != null) { - hitImage = true; - 
return record; - } - - throw new IllegalStateException("Indeterminate state: record must not be null, or a file iterator must exist"); - } - - @Override - public boolean hasNext() { - if (iter != null) { - return iter.hasNext(); - } else if (record != null) { - return !hitImage; - } - throw new IllegalStateException("Indeterminate state: record must not be null, or a file iterator must exist"); - } - - - @Override - public void close() throws IOException { - - } - - @Override - public void setConf(Configuration conf) { - this.conf = conf; - } - - @Override - public Configuration getConf() { - return conf; - } - - @Override - public List<String> getLabels() { - return null; - } - - - @Override - public void reset() { - if (inputSplit == null) - throw new UnsupportedOperationException("Cannot reset without first initializing"); - try { - initialize(inputSplit); - } catch (Exception e) { - throw new RuntimeException("Error during BaseAudioRecordReader reset", e); - } - } - - @Override - public boolean resetSupported(){ - if(inputSplit == null){ - return false; - } - return inputSplit.resetSupported(); - } - - @Override - public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException { - invokeListeners(uri); - try { - return loadData(null, dataInputStream); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public Record nextRecord() { - return new org.datavec.api.records.impl.Record(next(), null); - } - - @Override - public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException { - throw new UnsupportedOperationException("Loading from metadata not yet implemented"); - } - - @Override - public List<Record> loadFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException { - throw new UnsupportedOperationException("Loading from metadata not yet implemented"); - } -} diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/recordreader/NativeAudioRecordReader.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/recordreader/NativeAudioRecordReader.java deleted file mode 100644 index 62278e2d4..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/recordreader/NativeAudioRecordReader.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio.recordreader; - -import org.bytedeco.javacv.FFmpegFrameGrabber; -import org.bytedeco.javacv.Frame; -import org.datavec.api.writable.FloatWritable; -import org.datavec.api.writable.Writable; - -import java.io.File; -import java.io.IOException; -import java.io.InputStream; -import java.nio.FloatBuffer; -import java.util.ArrayList; -import java.util.List; - -import static org.bytedeco.ffmpeg.global.avutil.AV_SAMPLE_FMT_FLT; - -/** - * Native audio file loader using FFmpeg. - * - * @author saudet - */ -public class NativeAudioRecordReader extends BaseAudioRecordReader { - - public NativeAudioRecordReader() {} - - public NativeAudioRecordReader(boolean appendLabel, List<String> labels) { - super(appendLabel, labels); - } - - public NativeAudioRecordReader(List<String> labels) { - super(labels); - } - - public NativeAudioRecordReader(boolean appendLabel) { - super(appendLabel); - } - - protected List<Writable> loadData(File file, InputStream inputStream) throws IOException { - List<Writable> ret = new ArrayList<>(); - try (FFmpegFrameGrabber grabber = inputStream != null ? new FFmpegFrameGrabber(inputStream) - : new FFmpegFrameGrabber(file.getAbsolutePath())) { - grabber.setSampleFormat(AV_SAMPLE_FMT_FLT); - grabber.start(); - Frame frame; - while ((frame = grabber.grab()) != null) { - while (frame.samples != null && frame.samples[0].hasRemaining()) { - for (int i = 0; i < frame.samples.length; i++) { - ret.add(new FloatWritable(((FloatBuffer) frame.samples[i]).get())); - } - } - } - } - return ret; - } - -} diff --git a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/recordreader/WavFileRecordReader.java b/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/recordreader/WavFileRecordReader.java deleted file mode 100644 index 60f764a14..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/main/java/org/datavec/audio/recordreader/WavFileRecordReader.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
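
For reference, a minimal usage sketch of the NativeAudioRecordReader deleted above. This is illustrative only: the file path is hypothetical, and FFmpeg must be available at runtime (e.g. via the JavaCV/FFmpeg platform artifacts):

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.datavec.audio.recordreader.NativeAudioRecordReader;

import java.io.File;
import java.util.List;

public class NativeAudioReaderExample {
    public static void main(String[] args) throws Exception {
        RecordReader reader = new NativeAudioRecordReader();
        // Hypothetical path; any container/codec FFmpeg can decode should work
        reader.initialize(new FileSplit(new File("/path/to/audio.ogg")));
        while (reader.hasNext()) {
            // One record per file: a flat list of FloatWritable samples,
            // with channels interleaved in decode order
            List<Writable> samples = reader.next();
            System.out.println("Decoded " + samples.size() + " samples");
        }
        reader.close();
    }
}
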
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio.recordreader; - -import org.datavec.api.util.RecordUtils; -import org.datavec.api.writable.Writable; -import org.datavec.audio.Wave; - -import java.io.File; -import java.io.IOException; -import java.io.InputStream; -import java.util.List; - -/** - * Wav file loader - * @author Adam Gibson - */ -public class WavFileRecordReader extends BaseAudioRecordReader { - - public WavFileRecordReader() {} - - public WavFileRecordReader(boolean appendLabel, List<String> labels) { - super(appendLabel, labels); - } - - public WavFileRecordReader(List<String> labels) { - super(labels); - } - - public WavFileRecordReader(boolean appendLabel) { - super(appendLabel); - } - - protected List<Writable> loadData(File file, InputStream inputStream) throws IOException { - Wave wave = inputStream != null ? new Wave(inputStream) : new Wave(file.getAbsolutePath()); - return RecordUtils.toRecord(wave.getNormalizedAmplitudes()); - } - -} diff --git a/datavec/datavec-data/datavec-data-audio/src/test/java/org/datavec/audio/AssertTestsExtendBaseClass.java b/datavec/datavec-data/datavec-data-audio/src/test/java/org/datavec/audio/AssertTestsExtendBaseClass.java deleted file mode 100644 index 14b8459bb..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/test/java/org/datavec/audio/AssertTestsExtendBaseClass.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
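
The pure-Java WavFileRecordReader deleted above can be exercised the same way; a short sketch under the same assumptions (hypothetical path, imports as in the previous example), relying on Wave.getNormalizedAmplitudes() to scale samples:

RecordReader reader = new WavFileRecordReader();
reader.initialize(new FileSplit(new File("/path/to/sound.wav"))); // hypothetical path
List<Writable> amplitudes = reader.next(); // one DoubleWritable per sample, normalized amplitude
reader.close();
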
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ -package org.datavec.audio; - -import lombok.extern.slf4j.Slf4j; -import org.nd4j.common.tests.AbstractAssertTestsClass; -import org.nd4j.common.tests.BaseND4JTest; - -import java.util.*; - -@Slf4j -public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { - - @Override - public long getTimeoutMilliseconds() { - return 60000; - } - - @Override - protected Set<Class<?>> getExclusions() { - //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) - return new HashSet<>(); - } - - @Override - protected String getPackageName() { - return "org.datavec.audio"; - } - - @Override - protected Class<?> getBaseClass() { - return BaseND4JTest.class; - } -} diff --git a/datavec/datavec-data/datavec-data-audio/src/test/java/org/datavec/audio/AudioReaderTest.java b/datavec/datavec-data/datavec-data-audio/src/test/java/org/datavec/audio/AudioReaderTest.java deleted file mode 100644 index f2ee66345..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/test/java/org/datavec/audio/AudioReaderTest.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio; - -import org.bytedeco.javacv.FFmpegFrameRecorder; -import org.bytedeco.javacv.Frame; -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.split.FileSplit; -import org.datavec.api.writable.Writable; -import org.datavec.audio.recordreader.NativeAudioRecordReader; -import org.junit.Ignore; -import org.junit.Test; -import org.nd4j.common.tests.BaseND4JTest; - -import java.io.File; -import java.nio.ShortBuffer; -import java.util.List; - -import static org.bytedeco.ffmpeg.global.avcodec.AV_CODEC_ID_VORBIS; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -/** - * @author saudet - */ -public class AudioReaderTest extends BaseND4JTest { - @Ignore - @Test - public void testNativeAudioReader() throws Exception { - File tempFile = File.createTempFile("testNativeAudioReader", ".ogg"); - FFmpegFrameRecorder recorder = new FFmpegFrameRecorder(tempFile, 2); - recorder.setAudioCodec(AV_CODEC_ID_VORBIS); - recorder.setSampleRate(44100); - recorder.start(); - Frame audioFrame = new Frame(); - ShortBuffer audioBuffer = ShortBuffer.allocate(64 * 1024); - audioFrame.sampleRate = 44100; - audioFrame.audioChannels = 2; - audioFrame.samples = new ShortBuffer[] {audioBuffer}; - recorder.record(audioFrame); - recorder.stop(); - recorder.release(); - - RecordReader reader = new NativeAudioRecordReader(); - reader.initialize(new FileSplit(tempFile)); - assertTrue(reader.hasNext()); - List<Writable> record = reader.next(); - assertEquals(audioBuffer.limit(), record.size()); - } -} diff --git a/datavec/datavec-data/datavec-data-audio/src/test/java/org/datavec/audio/TestFastFourierTransform.java b/datavec/datavec-data/datavec-data-audio/src/test/java/org/datavec/audio/TestFastFourierTransform.java deleted file mode 100644 index 9b35beac6..000000000 --- a/datavec/datavec-data/datavec-data-audio/src/test/java/org/datavec/audio/TestFastFourierTransform.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.audio; - -import org.datavec.audio.dsp.FastFourierTransform; -import org.junit.Assert; -import org.junit.Test; -import org.nd4j.common.tests.BaseND4JTest; - -public class TestFastFourierTransform extends BaseND4JTest { - - @Test - public void testFastFourierTransformComplex() { - FastFourierTransform fft = new FastFourierTransform(); - double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6}; - double[] frequencies = fft.getMagnitudes(amplitudes); - - Assert.assertEquals(2, frequencies.length); - Assert.assertArrayEquals(new double[] {21.335, 18.513}, frequencies, 0.005); - } - - @Test - public void testFastFourierTransformComplexLong() { - FastFourierTransform fft = new FastFourierTransform(); - double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6}; - double[] frequencies = fft.getMagnitudes(amplitudes, true); - - Assert.assertEquals(4, frequencies.length); - Assert.assertArrayEquals(new double[] {21.335, 18.5132, 14.927, 7.527}, frequencies, 0.005); - } - - @Test - public void testFastFourierTransformReal() { - FastFourierTransform fft = new FastFourierTransform(); - double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5, 4.6}; - double[] frequencies = fft.getMagnitudes(amplitudes, false); - - Assert.assertEquals(4, frequencies.length); - Assert.assertArrayEquals(new double[] {28.8, 2.107, 14.927, 19.874}, frequencies, 0.005); - } - - @Test - public void testFastFourierTransformRealOddSize() { - FastFourierTransform fft = new FastFourierTransform(); - double[] amplitudes = new double[] {3.0, 4.0, 0.5, 7.8, 6.9, -6.5, 8.5}; - double[] frequencies = fft.getMagnitudes(amplitudes, false); - - Assert.assertEquals(3, frequencies.length); - Assert.assertArrayEquals(new double[] {24.2, 3.861, 16.876}, frequencies, 0.005); - } -} diff --git a/datavec/datavec-data/datavec-data-codec/pom.xml b/datavec/datavec-data/datavec-data-codec/pom.xml deleted file mode 100644 index 0a57f2d90..000000000 --- a/datavec/datavec-data/datavec-data-codec/pom.xml +++ /dev/null @@ -1,71 +0,0 @@ - - - - - - 4.0.0 - - - org.datavec - datavec-data - 1.0.0-SNAPSHOT - - - datavec-data-codec - - datavec-data-codec - - - - org.datavec - datavec-api - - - org.datavec - datavec-data-image - ${project.version} - - - org.jcodec - jcodec - 0.1.5 - - - - - - - - test-nd4j-native - - - test-nd4j-cuda-11.0 - - - diff --git a/datavec/datavec-data/datavec-data-codec/src/main/java/org/datavec/codec/format/input/CodecInputFormat.java b/datavec/datavec-data/datavec-data-codec/src/main/java/org/datavec/codec/format/input/CodecInputFormat.java deleted file mode 100644 index b2ea9628f..000000000 --- a/datavec/datavec-data/datavec-data-codec/src/main/java/org/datavec/codec/format/input/CodecInputFormat.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. 
- * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.codec.format.input; - -import org.datavec.api.conf.Configuration; -import org.datavec.api.formats.input.BaseInputFormat; -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.split.InputSplit; -import org.datavec.codec.reader.CodecRecordReader; - -import java.io.IOException; - -/** - * @author Adam Gibson - */ -public class CodecInputFormat extends BaseInputFormat { - @Override - public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException { - RecordReader reader = new CodecRecordReader(); - reader.initialize(conf, split); - return reader; - } -} diff --git a/datavec/datavec-data/datavec-data-codec/src/main/java/org/datavec/codec/reader/BaseCodecRecordReader.java b/datavec/datavec-data/datavec-data-codec/src/main/java/org/datavec/codec/reader/BaseCodecRecordReader.java deleted file mode 100644 index 09d660915..000000000 --- a/datavec/datavec-data/datavec-data-codec/src/main/java/org/datavec/codec/reader/BaseCodecRecordReader.java +++ /dev/null @@ -1,144 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
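
To illustrate the deleted CodecInputFormat in use: a hedged fragment (imports omitted; video path and frame counts are illustrative) that wires the configuration keys defined by BaseCodecRecordReader below into a sequence reader:

Configuration conf = new Configuration();
conf.set(CodecRecordReader.START_FRAME, "0");
conf.set(CodecRecordReader.TOTAL_FRAMES, "100");
conf.set(CodecRecordReader.ROWS, "80");
conf.set(CodecRecordReader.COLUMNS, "46");
conf.set(CodecRecordReader.RAVEL, "true");

// createReader initializes the underlying CodecRecordReader with the split and configuration;
// the cast is safe because BaseCodecRecordReader implements SequenceRecordReader
SequenceRecordReader reader = (SequenceRecordReader) new CodecInputFormat()
        .createReader(new FileSplit(new File("/path/to/video.mp4")), conf);
List<List<Writable>> frames = reader.sequenceRecord(); // one List<Writable> per decoded frame
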
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.codec.reader; - -import org.datavec.api.conf.Configuration; -import org.datavec.api.records.SequenceRecord; -import org.datavec.api.records.metadata.RecordMetaData; -import org.datavec.api.records.metadata.RecordMetaDataURI; -import org.datavec.api.records.reader.SequenceRecordReader; -import org.datavec.api.records.reader.impl.FileRecordReader; -import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; - -import java.io.DataInputStream; -import java.io.File; -import java.io.IOException; -import java.io.InputStream; -import java.net.URI; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -public abstract class BaseCodecRecordReader extends FileRecordReader implements SequenceRecordReader { - protected int startFrame = 0; - protected int numFrames = -1; - protected int totalFrames = -1; - protected double framesPerSecond = -1; - protected double videoLength = -1; - protected int rows = 28, cols = 28; - protected boolean ravel = false; - - public final static String NAME_SPACE = "org.datavec.codec.reader"; - public final static String ROWS = NAME_SPACE + ".rows"; - public final static String COLUMNS = NAME_SPACE + ".columns"; - public final static String START_FRAME = NAME_SPACE + ".startframe"; - public final static String TOTAL_FRAMES = NAME_SPACE + ".frames"; - public final static String TIME_SLICE = NAME_SPACE + ".time"; - public final static String RAVEL = NAME_SPACE + ".ravel"; - public final static String VIDEO_DURATION = NAME_SPACE + ".duration"; - - - @Override - public List<List<Writable>> sequenceRecord() { - URI next = locationsIterator.next(); - - try (InputStream s = streamCreatorFn.apply(next)){ - return loadData(null, s); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - @Override - public List<List<Writable>> sequenceRecord(URI uri, DataInputStream dataInputStream) throws IOException { - return loadData(null, dataInputStream); - } - - protected abstract List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException; - - - @Override - public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException { - setConf(conf); - initialize(split); - } - - @Override - public List<Writable> next() { - throw new UnsupportedOperationException("next() not supported for CodecRecordReader (use: sequenceRecord)"); - } - - @Override - public List<Writable> record(URI uri, DataInputStream dataInputStream) throws IOException { - throw new UnsupportedOperationException("record(URI,DataInputStream) not supported for CodecRecordReader"); - } - - @Override - public void setConf(Configuration conf) { - super.setConf(conf); - startFrame = conf.getInt(START_FRAME, 0); - numFrames = conf.getInt(TOTAL_FRAMES, -1); - rows = conf.getInt(ROWS, 28); - cols = conf.getInt(COLUMNS, 28); - framesPerSecond = conf.getFloat(TIME_SLICE, -1); - videoLength = conf.getFloat(VIDEO_DURATION, -1); - ravel = conf.getBoolean(RAVEL, false); - totalFrames = conf.getInt(TOTAL_FRAMES, -1); - } - - @Override - public Configuration getConf() { - return super.getConf(); - } - - @Override - public SequenceRecord nextSequence() { - URI next = locationsIterator.next(); - - List<List<Writable>> list; - try (InputStream s = streamCreatorFn.apply(next)){ - list = loadData(null, s); - } catch (IOException e) { - throw new RuntimeException(e); - } - return new org.datavec.api.records.impl.SequenceRecord(list, - new
RecordMetaDataURI(next, CodecRecordReader.class)); - } - - @Override - public SequenceRecord loadSequenceFromMetaData(RecordMetaData recordMetaData) throws IOException { - return loadSequenceFromMetaData(Collections.singletonList(recordMetaData)).get(0); - } - - @Override - public List<SequenceRecord> loadSequenceFromMetaData(List<RecordMetaData> recordMetaDatas) throws IOException { - List<SequenceRecord> out = new ArrayList<>(); - for (RecordMetaData meta : recordMetaDatas) { - try (InputStream s = streamCreatorFn.apply(meta.getURI())){ - List<List<Writable>> list = loadData(null, s); - out.add(new org.datavec.api.records.impl.SequenceRecord(list, meta)); - } - } - - return out; - } -} diff --git a/datavec/datavec-data/datavec-data-codec/src/main/java/org/datavec/codec/reader/CodecRecordReader.java b/datavec/datavec-data/datavec-data-codec/src/main/java/org/datavec/codec/reader/CodecRecordReader.java deleted file mode 100644 index 8ef4e8c68..000000000 --- a/datavec/datavec-data/datavec-data-codec/src/main/java/org/datavec/codec/reader/CodecRecordReader.java +++ /dev/null @@ -1,138 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.codec.reader; - -import org.apache.commons.compress.utils.IOUtils; -import org.datavec.api.conf.Configuration; -import org.datavec.api.util.ndarray.RecordConverter; -import org.datavec.api.writable.Writable; -import org.datavec.image.loader.ImageLoader; -import org.jcodec.api.FrameGrab; -import org.jcodec.api.JCodecException; -import org.jcodec.common.ByteBufferSeekableByteChannel; -import org.jcodec.common.NIOUtils; -import org.jcodec.common.SeekableByteChannel; - -import java.awt.image.BufferedImage; -import java.io.File; -import java.io.IOException; -import java.io.InputStream; -import java.lang.reflect.Field; -import java.nio.ByteBuffer; -import java.util.ArrayList; -import java.util.List; - -public class CodecRecordReader extends BaseCodecRecordReader { - - private ImageLoader imageLoader; - - @Override - public void setConf(Configuration conf) { - super.setConf(conf); - imageLoader = new ImageLoader(rows, cols); - } - - @Override - protected List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException { - SeekableByteChannel seekableByteChannel; - if (inputStream != null) { - //Reading video from DataInputStream: Need data from this stream in a SeekableByteChannel - //Approach used here: load entire video into memory -> ByteBufferSeekableByteChannel - byte[] data = IOUtils.toByteArray(inputStream); - ByteBuffer bb = ByteBuffer.wrap(data); - seekableByteChannel = new FixedByteBufferSeekableByteChannel(bb); - } else { - seekableByteChannel = NIOUtils.readableFileChannel(file); - } - - List<List<Writable>> record = new ArrayList<>(); - - if (numFrames >= 1) { - FrameGrab fg; - try { - fg = new FrameGrab(seekableByteChannel); - if (startFrame != 0) - fg.seekToFramePrecise(startFrame); - } catch (JCodecException e) { - throw new RuntimeException(e); - } - - for (int i = startFrame; i < startFrame + numFrames; i++) { - try { - BufferedImage grab = fg.getFrame(); - if (ravel) - record.add(RecordConverter.toRecord(imageLoader.toRaveledTensor(grab))); - else - record.add(RecordConverter.toRecord(imageLoader.asRowVector(grab))); - - } catch (Exception e) { - throw new RuntimeException(e); - } - } - } else { - if (framesPerSecond < 1) - throw new IllegalStateException("No frames or frame time intervals specified"); - - - else { - for (double i = 0; i < videoLength; i += framesPerSecond) { - try { - BufferedImage grab = FrameGrab.getFrame(seekableByteChannel, i); - if (ravel) - record.add(RecordConverter.toRecord(imageLoader.toRaveledTensor(grab))); - else - record.add(RecordConverter.toRecord(imageLoader.asRowVector(grab))); - - } catch (Exception e) { - throw new RuntimeException(e); - } - } - } - } - - return record; - } - - /** Ugly workaround to a bug in JCodec: https://github.com/jcodec/jcodec/issues/24 */ - private static class FixedByteBufferSeekableByteChannel extends ByteBufferSeekableByteChannel { - private ByteBuffer backing; - - public FixedByteBufferSeekableByteChannel(ByteBuffer backing) { - super(backing); - try { - Field f = this.getClass().getSuperclass().getDeclaredField("maxPos"); - f.setAccessible(true); - f.set(this, backing.limit()); - } catch (Exception e) { - throw new RuntimeException(e); - } - this.backing = backing; - } - - @Override - public int read(ByteBuffer dst) throws IOException { - if (!backing.hasRemaining()) - return -1; - return super.read(dst); - } - } - -} diff --git
a/datavec/datavec-data/datavec-data-codec/src/main/java/org/datavec/codec/reader/NativeCodecRecordReader.java b/datavec/datavec-data/datavec-data-codec/src/main/java/org/datavec/codec/reader/NativeCodecRecordReader.java deleted file mode 100644 index 16acecdc0..000000000 --- a/datavec/datavec-data/datavec-data-codec/src/main/java/org/datavec/codec/reader/NativeCodecRecordReader.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.codec.reader; - -import org.bytedeco.javacv.FFmpegFrameGrabber; -import org.bytedeco.javacv.Frame; -import org.bytedeco.javacv.OpenCVFrameConverter; -import org.datavec.api.conf.Configuration; -import org.datavec.api.util.ndarray.RecordConverter; -import org.datavec.api.writable.Writable; -import org.datavec.image.loader.NativeImageLoader; - -import java.io.File; -import java.io.IOException; -import java.io.InputStream; -import java.util.ArrayList; -import java.util.List; - -public class NativeCodecRecordReader extends BaseCodecRecordReader { - - private OpenCVFrameConverter.ToMat converter; - private NativeImageLoader imageLoader; - - @Override - public void setConf(Configuration conf) { - super.setConf(conf); - converter = new OpenCVFrameConverter.ToMat(); - imageLoader = new NativeImageLoader(rows, cols); - } - - @Override - protected List<List<Writable>> loadData(File file, InputStream inputStream) throws IOException { - List<List<Writable>> record = new ArrayList<>(); - - try (FFmpegFrameGrabber fg = - inputStream != null ?
new FFmpegFrameGrabber(inputStream) : new FFmpegFrameGrabber(file)) { - if (numFrames >= 1) { - fg.start(); - if (startFrame != 0) - fg.setFrameNumber(startFrame); - - for (int i = startFrame; i < startFrame + numFrames; i++) { - Frame grab = fg.grabImage(); - record.add(RecordConverter.toRecord(imageLoader.asRowVector(converter.convert(grab)))); - } - } else { - if (framesPerSecond < 1) - throw new IllegalStateException("No frames or frame time intervals specified"); - else { - fg.start(); - - for (double i = 0; i < videoLength; i += framesPerSecond) { - fg.setTimestamp(Math.round(i * 1000000L)); - Frame grab = fg.grabImage(); - record.add(RecordConverter.toRecord(imageLoader.asRowVector(converter.convert(grab)))); - } - } - } - } - - return record; - } - -} diff --git a/datavec/datavec-data/datavec-data-codec/src/test/java/org/datavec/codec/reader/AssertTestsExtendBaseClass.java b/datavec/datavec-data/datavec-data-codec/src/test/java/org/datavec/codec/reader/AssertTestsExtendBaseClass.java deleted file mode 100644 index 17da1b32e..000000000 --- a/datavec/datavec-data/datavec-data-codec/src/test/java/org/datavec/codec/reader/AssertTestsExtendBaseClass.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ -package org.datavec.codec.reader; - -import lombok.extern.slf4j.Slf4j; -import org.nd4j.common.tests.AbstractAssertTestsClass; -import org.nd4j.common.tests.BaseND4JTest; - -import java.util.*; - -@Slf4j -public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { - - @Override - protected Set<Class<?>> getExclusions() { - //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) - return new HashSet<>(); - } - - @Override - protected String getPackageName() { - return "org.datavec.codec.reader"; - } - - @Override - protected Class<?> getBaseClass() { - return BaseND4JTest.class; - } -} diff --git a/datavec/datavec-data/datavec-data-codec/src/test/java/org/datavec/codec/reader/CodecReaderTest.java b/datavec/datavec-data/datavec-data-codec/src/test/java/org/datavec/codec/reader/CodecReaderTest.java deleted file mode 100644 index fff203829..000000000 --- a/datavec/datavec-data/datavec-data-codec/src/test/java/org/datavec/codec/reader/CodecReaderTest.java +++ /dev/null @@ -1,212 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0.
- * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.codec.reader; - -import org.datavec.api.conf.Configuration; -import org.datavec.api.records.SequenceRecord; -import org.datavec.api.records.metadata.RecordMetaData; -import org.datavec.api.records.reader.SequenceRecordReader; -import org.datavec.api.split.FileSplit; -import org.datavec.api.writable.ArrayWritable; -import org.datavec.api.writable.Writable; -import org.junit.Ignore; -import org.junit.Test; -import org.nd4j.common.io.ClassPathResource; - -import java.io.DataInputStream; -import java.io.File; -import java.io.FileInputStream; -import java.util.Iterator; -import java.util.List; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -/** - * @author Adam Gibson - */ -public class CodecReaderTest { - @Test - public void testCodecReader() throws Exception { - File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile(); - SequenceRecordReader reader = new CodecRecordReader(); - Configuration conf = new Configuration(); - conf.set(CodecRecordReader.RAVEL, "true"); - conf.set(CodecRecordReader.START_FRAME, "160"); - conf.set(CodecRecordReader.TOTAL_FRAMES, "500"); - conf.set(CodecRecordReader.ROWS, "80"); - conf.set(CodecRecordReader.COLUMNS, "46"); - reader.initialize(new FileSplit(file)); - reader.setConf(conf); - assertTrue(reader.hasNext()); - List> record = reader.sequenceRecord(); - // System.out.println(record.size()); - - Iterator> it = record.iterator(); - List first = it.next(); - // System.out.println(first); - - //Expected size: 80x46x3 - assertEquals(1, first.size()); - assertEquals(80 * 46 * 3, ((ArrayWritable) first.iterator().next()).length()); - } - - @Test - public void testCodecReaderMeta() throws Exception { - File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile(); - SequenceRecordReader reader = new CodecRecordReader(); - Configuration conf = new Configuration(); - conf.set(CodecRecordReader.RAVEL, "true"); - conf.set(CodecRecordReader.START_FRAME, "160"); - conf.set(CodecRecordReader.TOTAL_FRAMES, "500"); - conf.set(CodecRecordReader.ROWS, "80"); - conf.set(CodecRecordReader.COLUMNS, "46"); - reader.initialize(new FileSplit(file)); - reader.setConf(conf); - assertTrue(reader.hasNext()); - List> record = reader.sequenceRecord(); - assertEquals(500, record.size()); //500 frames - - reader.reset(); - SequenceRecord seqR = reader.nextSequence(); - assertEquals(record, seqR.getSequenceRecord()); - RecordMetaData meta = seqR.getMetaData(); - // System.out.println(meta); - assertTrue(meta.getURI().toString().endsWith(file.getName())); - - SequenceRecord fromMeta = reader.loadSequenceFromMetaData(meta); - assertEquals(seqR, fromMeta); - } - - @Test - public void testViaDataInputStream() throws Exception { - - File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile(); - SequenceRecordReader reader = new CodecRecordReader(); - Configuration conf = new 
Configuration(); - conf.set(CodecRecordReader.RAVEL, "true"); - conf.set(CodecRecordReader.START_FRAME, "160"); - conf.set(CodecRecordReader.TOTAL_FRAMES, "500"); - conf.set(CodecRecordReader.ROWS, "80"); - conf.set(CodecRecordReader.COLUMNS, "46"); - - Configuration conf2 = new Configuration(conf); - - reader.initialize(new FileSplit(file)); - reader.setConf(conf); - assertTrue(reader.hasNext()); - List> expected = reader.sequenceRecord(); - - - SequenceRecordReader reader2 = new CodecRecordReader(); - reader2.setConf(conf2); - - DataInputStream dataInputStream = new DataInputStream(new FileInputStream(file)); - List> actual = reader2.sequenceRecord(null, dataInputStream); - - assertEquals(expected, actual); - } - - - @Ignore - @Test - public void testNativeCodecReader() throws Exception { - File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile(); - SequenceRecordReader reader = new NativeCodecRecordReader(); - Configuration conf = new Configuration(); - conf.set(CodecRecordReader.RAVEL, "true"); - conf.set(CodecRecordReader.START_FRAME, "160"); - conf.set(CodecRecordReader.TOTAL_FRAMES, "500"); - conf.set(CodecRecordReader.ROWS, "80"); - conf.set(CodecRecordReader.COLUMNS, "46"); - reader.initialize(new FileSplit(file)); - reader.setConf(conf); - assertTrue(reader.hasNext()); - List> record = reader.sequenceRecord(); - // System.out.println(record.size()); - - Iterator> it = record.iterator(); - List first = it.next(); - // System.out.println(first); - - //Expected size: 80x46x3 - assertEquals(1, first.size()); - assertEquals(80 * 46 * 3, ((ArrayWritable) first.iterator().next()).length()); - } - - @Ignore - @Test - public void testNativeCodecReaderMeta() throws Exception { - File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile(); - SequenceRecordReader reader = new NativeCodecRecordReader(); - Configuration conf = new Configuration(); - conf.set(CodecRecordReader.RAVEL, "true"); - conf.set(CodecRecordReader.START_FRAME, "160"); - conf.set(CodecRecordReader.TOTAL_FRAMES, "500"); - conf.set(CodecRecordReader.ROWS, "80"); - conf.set(CodecRecordReader.COLUMNS, "46"); - reader.initialize(new FileSplit(file)); - reader.setConf(conf); - assertTrue(reader.hasNext()); - List> record = reader.sequenceRecord(); - assertEquals(500, record.size()); //500 frames - - reader.reset(); - SequenceRecord seqR = reader.nextSequence(); - assertEquals(record, seqR.getSequenceRecord()); - RecordMetaData meta = seqR.getMetaData(); - // System.out.println(meta); - assertTrue(meta.getURI().toString().endsWith("fire_lowres.mp4")); - - SequenceRecord fromMeta = reader.loadSequenceFromMetaData(meta); - assertEquals(seqR, fromMeta); - } - - @Ignore - @Test - public void testNativeViaDataInputStream() throws Exception { - - File file = new ClassPathResource("datavec-data-codec/fire_lowres.mp4").getFile(); - SequenceRecordReader reader = new NativeCodecRecordReader(); - Configuration conf = new Configuration(); - conf.set(CodecRecordReader.RAVEL, "true"); - conf.set(CodecRecordReader.START_FRAME, "160"); - conf.set(CodecRecordReader.TOTAL_FRAMES, "500"); - conf.set(CodecRecordReader.ROWS, "80"); - conf.set(CodecRecordReader.COLUMNS, "46"); - - Configuration conf2 = new Configuration(conf); - - reader.initialize(new FileSplit(file)); - reader.setConf(conf); - assertTrue(reader.hasNext()); - List> expected = reader.sequenceRecord(); - - - SequenceRecordReader reader2 = new NativeCodecRecordReader(); - reader2.setConf(conf2); - - DataInputStream dataInputStream = new 
DataInputStream(new FileInputStream(file)); - List> actual = reader2.sequenceRecord(null, dataInputStream); - - assertEquals(expected, actual); - } -} diff --git a/datavec/datavec-data/datavec-data-nlp/pom.xml b/datavec/datavec-data/datavec-data-nlp/pom.xml deleted file mode 100644 index 475c13c2e..000000000 --- a/datavec/datavec-data/datavec-data-nlp/pom.xml +++ /dev/null @@ -1,77 +0,0 @@ - - - - - - 4.0.0 - - - org.datavec - datavec-data - 1.0.0-SNAPSHOT - - - datavec-data-nlp - - datavec-data-nlp - - - 2.0.0 - - - - - org.datavec - datavec-api - - - org.datavec - datavec-local - ${project.version} - test - - - org.apache.commons - commons-lang3 - - - org.cleartk - cleartk-snowball - ${cleartk.version} - - - org.cleartk - cleartk-opennlp-tools - ${cleartk.version} - - - - - - test-nd4j-native - - - test-nd4j-cuda-11.0 - - - diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/annotator/PoStagger.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/annotator/PoStagger.java deleted file mode 100644 index 7934a7cba..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/annotator/PoStagger.java +++ /dev/null @@ -1,237 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.annotator; - -import opennlp.tools.postag.POSModel; -import opennlp.tools.postag.POSTaggerME; -import opennlp.uima.postag.POSModelResource; -import opennlp.uima.postag.POSModelResourceImpl; -import opennlp.uima.util.AnnotationComboIterator; -import opennlp.uima.util.AnnotationIteratorPair; -import opennlp.uima.util.AnnotatorUtil; -import opennlp.uima.util.UimaUtil; -import org.apache.uima.UimaContext; -import org.apache.uima.analysis_engine.AnalysisEngineDescription; -import org.apache.uima.analysis_engine.AnalysisEngineProcessException; -import org.apache.uima.cas.CAS; -import org.apache.uima.cas.Feature; -import org.apache.uima.cas.Type; -import org.apache.uima.cas.TypeSystem; -import org.apache.uima.cas.text.AnnotationFS; -import org.apache.uima.fit.component.CasAnnotator_ImplBase; -import org.apache.uima.fit.factory.AnalysisEngineFactory; -import org.apache.uima.fit.factory.ExternalResourceFactory; -import org.apache.uima.resource.ResourceAccessException; -import org.apache.uima.resource.ResourceInitializationException; -import org.apache.uima.util.Level; -import org.apache.uima.util.Logger; -import org.cleartk.token.type.Sentence; -import org.cleartk.token.type.Token; -import org.datavec.nlp.movingwindow.Util; - -import java.util.Iterator; -import java.util.LinkedList; -import java.util.List; - - -public class PoStagger extends CasAnnotator_ImplBase { - - static { - //UIMA logging - Util.disableLogging(); - } - - private POSTaggerME posTagger; - - private Type sentenceType; - - private Type tokenType; - - private Feature posFeature; - - private Feature probabilityFeature; - - private UimaContext context; - - private Logger logger; - - /** - * Initializes a new instance. - * - * Note: Use {@link #initialize(org.apache.uima.UimaContext) } to initialize this instance. Not use the - * constructor. - */ - public PoStagger() { - // must not be implemented ! - } - - /** - * Initializes the current instance with the given context. - * - * Note: Do all initialization in this method, do not use the constructor. - */ - @Override - public void initialize(UimaContext context) throws ResourceInitializationException { - - super.initialize(context); - - this.context = context; - - this.logger = context.getLogger(); - - if (this.logger.isLoggable(Level.INFO)) { - this.logger.log(Level.INFO, "Initializing the OpenNLP " + "Part of Speech annotator."); - } - - POSModel model; - - try { - POSModelResource modelResource = (POSModelResource) context.getResourceObject(UimaUtil.MODEL_PARAMETER); - - model = modelResource.getModel(); - } catch (ResourceAccessException e) { - throw new ResourceInitializationException(e); - } - - Integer beamSize = AnnotatorUtil.getOptionalIntegerParameter(context, UimaUtil.BEAM_SIZE_PARAMETER); - - if (beamSize == null) - beamSize = POSTaggerME.DEFAULT_BEAM_SIZE; - - this.posTagger = new POSTaggerME(model, beamSize, 0); - } - - /** - * Initializes the type system. 
- */ - @Override - public void typeSystemInit(TypeSystem typeSystem) throws AnalysisEngineProcessException { - - // sentence type - this.sentenceType = AnnotatorUtil.getRequiredTypeParameter(this.context, typeSystem, - UimaUtil.SENTENCE_TYPE_PARAMETER); - - // token type - this.tokenType = AnnotatorUtil.getRequiredTypeParameter(this.context, typeSystem, - UimaUtil.TOKEN_TYPE_PARAMETER); - - // pos feature - this.posFeature = AnnotatorUtil.getRequiredFeatureParameter(this.context, this.tokenType, - UimaUtil.POS_FEATURE_PARAMETER, CAS.TYPE_NAME_STRING); - - this.probabilityFeature = AnnotatorUtil.getOptionalFeatureParameter(this.context, this.tokenType, - UimaUtil.PROBABILITY_FEATURE_PARAMETER, CAS.TYPE_NAME_DOUBLE); - } - - /** - * Performs pos-tagging on the given tcas object. - */ - @Override - public synchronized void process(CAS tcas) { - - final AnnotationComboIterator comboIterator = - new AnnotationComboIterator(tcas, this.sentenceType, this.tokenType); - - for (AnnotationIteratorPair annotationIteratorPair : comboIterator) { - - final List<AnnotationFS> sentenceTokenAnnotationList = new LinkedList<>(); - - final List<String> sentenceTokenList = new LinkedList<>(); - - for (AnnotationFS tokenAnnotation : annotationIteratorPair.getSubIterator()) { - - sentenceTokenAnnotationList.add(tokenAnnotation); - - sentenceTokenList.add(tokenAnnotation.getCoveredText()); - } - - final List<String> posTags = this.posTagger.tag(sentenceTokenList); - - double[] posProbabilities = null; - - if (this.probabilityFeature != null) { - posProbabilities = this.posTagger.probs(); - } - - final Iterator<String> posTagIterator = posTags.iterator(); - final Iterator<AnnotationFS> sentenceTokenIterator = sentenceTokenAnnotationList.iterator(); - - int index = 0; - while (posTagIterator.hasNext() && sentenceTokenIterator.hasNext()) { - final String posTag = posTagIterator.next(); - final AnnotationFS tokenAnnotation = sentenceTokenIterator.next(); - - tokenAnnotation.setStringValue(this.posFeature, posTag); - - if (posProbabilities != null) { - tokenAnnotation.setDoubleValue(this.probabilityFeature, posProbabilities[index]); - } - - index++; - } - - // log tokens with pos - if (this.logger.isLoggable(Level.FINER)) { - - final StringBuilder sentenceWithPos = new StringBuilder(); - - sentenceWithPos.append("\""); - - for (final Iterator<AnnotationFS> it = sentenceTokenAnnotationList.iterator(); it.hasNext();) { - final AnnotationFS token = it.next(); - sentenceWithPos.append(token.getCoveredText()); - sentenceWithPos.append('\\'); - sentenceWithPos.append(token.getStringValue(this.posFeature)); - sentenceWithPos.append(' '); - } - - // delete last whitespace - if (sentenceWithPos.length() > 1) // not 0 because it contains already the " char - sentenceWithPos.setLength(sentenceWithPos.length() - 1); - - sentenceWithPos.append("\""); - - this.logger.log(Level.FINER, sentenceWithPos.toString()); - } - } - } - - /** - * Releases allocated resources.
- */ - @Override - public void destroy() { - this.posTagger = null; - } - - - public static AnalysisEngineDescription getDescription(String languageCode) throws ResourceInitializationException { - String modelPath = String.format("/models/%s-pos-maxent.bin", languageCode); - return AnalysisEngineFactory.createEngineDescription(PoStagger.class, UimaUtil.MODEL_PARAMETER, - ExternalResourceFactory.createExternalResourceDescription(POSModelResourceImpl.class, - PoStagger.class.getResource(modelPath).toString()), - UimaUtil.SENTENCE_TYPE_PARAMETER, Sentence.class.getName(), UimaUtil.TOKEN_TYPE_PARAMETER, - Token.class.getName(), UimaUtil.POS_FEATURE_PARAMETER, "pos"); - } - - - -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/annotator/SentenceAnnotator.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/annotator/SentenceAnnotator.java deleted file mode 100644 index 69491b99c..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/annotator/SentenceAnnotator.java +++ /dev/null @@ -1,52 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.annotator; - -import org.apache.uima.analysis_engine.AnalysisEngineDescription; -import org.apache.uima.analysis_engine.AnalysisEngineProcessException; -import org.apache.uima.fit.factory.AnalysisEngineFactory; -import org.apache.uima.jcas.JCas; -import org.apache.uima.resource.ResourceInitializationException; -import org.cleartk.util.ParamUtil; -import org.datavec.nlp.movingwindow.Util; - -public class SentenceAnnotator extends org.cleartk.opennlp.tools.SentenceAnnotator { - - static { - //UIMA logging - Util.disableLogging(); - } - - public static AnalysisEngineDescription getDescription() throws ResourceInitializationException { - return AnalysisEngineFactory.createEngineDescription(SentenceAnnotator.class, PARAM_SENTENCE_MODEL_PATH, - ParamUtil.getParameterValue(PARAM_SENTENCE_MODEL_PATH, "/models/en-sent.bin"), - PARAM_WINDOW_CLASS_NAMES, ParamUtil.getParameterValue(PARAM_WINDOW_CLASS_NAMES, null)); - } - - - @Override - public synchronized void process(JCas jCas) throws AnalysisEngineProcessException { - super.process(jCas); - } - - - -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/annotator/StemmerAnnotator.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/annotator/StemmerAnnotator.java deleted file mode 100644 index c6d7438d3..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/annotator/StemmerAnnotator.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.annotator; - -import org.apache.uima.analysis_engine.AnalysisEngineDescription; -import org.apache.uima.analysis_engine.AnalysisEngineProcessException; -import org.apache.uima.fit.factory.AnalysisEngineFactory; -import org.apache.uima.jcas.JCas; -import org.apache.uima.resource.ResourceInitializationException; -import org.cleartk.snowball.SnowballStemmer; -import org.cleartk.token.type.Token; - - -public class StemmerAnnotator extends SnowballStemmer<Token> { - - public static AnalysisEngineDescription getDescription() throws ResourceInitializationException { - return getDescription("English"); - } - - - public static AnalysisEngineDescription getDescription(String language) throws ResourceInitializationException { - return AnalysisEngineFactory.createEngineDescription(StemmerAnnotator.class, SnowballStemmer.PARAM_STEMMER_NAME, - language); - } - - - @SuppressWarnings("unchecked") - @Override - public synchronized void process(JCas jCas) throws AnalysisEngineProcessException { - super.process(jCas); - } - - - - @Override - public void setStem(Token token, String stem) { - token.setStem(stem); - } - -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/annotator/TokenizerAnnotator.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/annotator/TokenizerAnnotator.java deleted file mode 100644 index a9eef5e7f..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/annotator/TokenizerAnnotator.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.annotator; - - -import opennlp.uima.tokenize.TokenizerModelResourceImpl; -import org.apache.uima.analysis_engine.AnalysisEngineDescription; -import org.apache.uima.fit.factory.AnalysisEngineFactory; -import org.apache.uima.fit.factory.ExternalResourceFactory; -import org.apache.uima.resource.ResourceInitializationException; -import org.cleartk.opennlp.tools.Tokenizer; -import org.cleartk.token.type.Sentence; -import org.cleartk.token.type.Token; -import org.datavec.nlp.movingwindow.Util; -import org.datavec.nlp.tokenization.tokenizer.ConcurrentTokenizer; - - -/** - * Overrides OpenNLP tokenizer to be thread safe - */ -public class TokenizerAnnotator extends Tokenizer { - - static { - //UIMA logging - Util.disableLogging(); - } - - public static AnalysisEngineDescription getDescription(String languageCode) throws ResourceInitializationException { - String modelPath = String.format("/models/%s-token.bin", languageCode); - return AnalysisEngineFactory.createEngineDescription(ConcurrentTokenizer.class, - opennlp.uima.util.UimaUtil.MODEL_PARAMETER, - ExternalResourceFactory.createExternalResourceDescription(TokenizerModelResourceImpl.class, - ConcurrentTokenizer.class.getResource(modelPath).toString()), - opennlp.uima.util.UimaUtil.SENTENCE_TYPE_PARAMETER, Sentence.class.getName(), - opennlp.uima.util.UimaUtil.TOKEN_TYPE_PARAMETER, Token.class.getName()); - } - - - - public static AnalysisEngineDescription getDescription() throws ResourceInitializationException { - String modelPath = String.format("/models/%s-token.bin", "en"); - return AnalysisEngineFactory.createEngineDescription(ConcurrentTokenizer.class, - opennlp.uima.util.UimaUtil.MODEL_PARAMETER, - ExternalResourceFactory.createExternalResourceDescription(TokenizerModelResourceImpl.class, - ConcurrentTokenizer.class.getResource(modelPath).toString()), - opennlp.uima.util.UimaUtil.SENTENCE_TYPE_PARAMETER, Sentence.class.getName(), - opennlp.uima.util.UimaUtil.TOKEN_TYPE_PARAMETER, Token.class.getName()); - } - - - -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/input/TextInputFormat.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/input/TextInputFormat.java deleted file mode 100644 index a1f1a3a65..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/input/TextInputFormat.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
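
Taken together, the three annotators above were designed to compose into a UIMA pipeline. A hedged uimaFIT sketch follows; the ordering and sample text are illustrative, and it assumes the en-sent, en-token, and en-pos-maxent model files are on the classpath:

import org.apache.uima.fit.factory.JCasFactory;
import org.apache.uima.fit.pipeline.SimplePipeline;
import org.apache.uima.jcas.JCas;
import org.datavec.nlp.annotator.PoStagger;
import org.datavec.nlp.annotator.SentenceAnnotator;
import org.datavec.nlp.annotator.TokenizerAnnotator;

public class PosPipelineExample {
    public static void main(String[] args) throws Exception {
        JCas jCas = JCasFactory.createJCas();
        jCas.setDocumentText("DataVec reads records. It also tags text."); // illustrative text
        // Detect sentences first, then tokens, then assign POS tags
        SimplePipeline.runPipeline(jCas,
                SentenceAnnotator.getDescription(),
                TokenizerAnnotator.getDescription(),
                PoStagger.getDescription("en"));
    }
}
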
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.input; - -import org.datavec.api.conf.Configuration; -import org.datavec.api.formats.input.BaseInputFormat; -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.split.InputSplit; -import org.datavec.nlp.reader.TfidfRecordReader; - -import java.io.IOException; - -/** - * @author Adam Gibson - */ -public class TextInputFormat extends BaseInputFormat { - @Override - public RecordReader createReader(InputSplit split, Configuration conf) throws IOException, InterruptedException { - RecordReader reader = new TfidfRecordReader(); - reader.initialize(conf, split); - return reader; - } -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/metadata/DefaultVocabCache.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/metadata/DefaultVocabCache.java deleted file mode 100644 index 10a05e89d..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/metadata/DefaultVocabCache.java +++ /dev/null @@ -1,148 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.metadata; - -import org.nd4j.common.primitives.Counter; -import org.datavec.api.conf.Configuration; -import org.datavec.nlp.vectorizer.TextVectorizer; -import org.nd4j.common.util.MathUtils; -import org.nd4j.common.util.Index; - -public class DefaultVocabCache implements VocabCache { - - private Counter wordFrequencies = new Counter<>(); - private Counter docFrequencies = new Counter<>(); - private int minWordFrequency; - private Index vocabWords = new Index(); - private double numDocs = 0; - - /** - * Instantiate with a given min word frequency - * @param minWordFrequency - */ - public DefaultVocabCache(int minWordFrequency) { - this.minWordFrequency = minWordFrequency; - } - - /* - * Constructor for use with initialize() - */ - public DefaultVocabCache() { - } - - @Override - public void incrementNumDocs(double by) { - numDocs += by; - } - - @Override - public double numDocs() { - return numDocs; - } - - @Override - public String wordAt(int i) { - return vocabWords.get(i).toString(); - } - - @Override - public int wordIndex(String word) { - return vocabWords.indexOf(word); - } - - @Override - public void initialize(Configuration conf) { - minWordFrequency = conf.getInt(TextVectorizer.MIN_WORD_FREQUENCY, 5); - } - - @Override - public double wordFrequency(String word) { - return wordFrequencies.getCount(word); - } - - @Override - public int minWordFrequency() { - return minWordFrequency; - } - - @Override - public Index vocabWords() { - return vocabWords; - } - - @Override - public void incrementDocCount(String word) { - incrementDocCount(word, 1.0); - } - - @Override - public void incrementDocCount(String word, double by) { - docFrequencies.incrementCount(word, by); - - } - - @Override - public void incrementCount(String word) { - incrementCount(word, 1.0); - } - - @Override - public void incrementCount(String word, double by) { - wordFrequencies.incrementCount(word, by); - if (wordFrequencies.getCount(word) >= minWordFrequency && vocabWords.indexOf(word) < 0) - vocabWords.add(word); - } - - @Override - public double idf(String word) { - return docFrequencies.getCount(word); - } - - @Override - public double tfidf(String word, double frequency, boolean smoothIdf) { - double tf = tf((int) frequency); - double docFreq = docFrequencies.getCount(word); - - double idf = idf(numDocs, docFreq, smoothIdf); - double tfidf = MathUtils.tfidf(tf, idf); - return tfidf; - } - - public double idf(double totalDocs, double numTimesWordAppearedInADocument, boolean smooth) { - if(smooth){ - return Math.log((1 + totalDocs) / (1 + numTimesWordAppearedInADocument)) + 1.0; - } else { - return Math.log(totalDocs / numTimesWordAppearedInADocument) + 1.0; - } - } - - public static double tf(int count) { - return count; - } - - public int getMinWordFrequency() { - return minWordFrequency; - } - - public void setMinWordFrequency(int minWordFrequency) { - this.minWordFrequency = minWordFrequency; - } -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/metadata/VocabCache.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/metadata/VocabCache.java deleted file mode 100644 index e40628884..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/metadata/VocabCache.java +++ /dev/null @@ -1,121 +0,0 @@ -/* - * 
****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.metadata; - - -import org.datavec.api.conf.Configuration; -import org.nd4j.common.util.Index; - -public interface VocabCache { - - - /** - * Increment the number of documents - * @param by - */ - void incrementNumDocs(double by); - - /** - * Number of documents - * @return the number of documents - */ - double numDocs(); - - /** - * Returns a word in the vocab at a particular index - * @param i the index to get - * @return the word at that index in the vocab - */ - String wordAt(int i); - - int wordIndex(String word); - - /** - * Configuration for initializing - * @param conf the configuration to initialize with - */ - void initialize(Configuration conf); - - /** - * Get the word frequency for a word - * @param word the word to get frequency for - * @return the frequency for a given word - */ - double wordFrequency(String word); - - /** - * The min word frequency - * needed to be included in the vocab - * (default 5) - * @return the min word frequency to - * be included in the vocab - */ - int minWordFrequency(); - - /** - * All of the vocab words (ordered) - * note that these are not all the possible tokens - * @return the list of vocab words - */ - Index vocabWords(); - - - /** - * Increment the doc count for a word by 1 - * @param word the word to increment the count for - */ - void incrementDocCount(String word); - - /** - * Increment the document count for a particular word - * @param word the word to increment the count for - * @param by the amount to increment by - */ - void incrementDocCount(String word, double by); - - /** - * Increment a word count by 1 - * @param word the word to increment the count for - */ - void incrementCount(String word); - - /** - * Increment count for a word - * @param word the word to increment the count for - * @param by the amount to increment by - */ - void incrementCount(String word, double by); - - /** - * Number of documents word has occurred in - * @param word the word to get the idf for - */ - double idf(String word); - - /** - * Calculate the tfidf of the word given the document frequency - * @param word the word to get frequency for - * @param frequency the frequency - * @return the tfidf for a word - */ - double tfidf(String word, double frequency, boolean smoothIdf); - -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/movingwindow/ContextLabelRetriever.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/movingwindow/ContextLabelRetriever.java deleted file mode 100644 index ee767d22c..000000000 --- 
a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/movingwindow/ContextLabelRetriever.java
+++ /dev/null
@@ -1,125 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.nlp.movingwindow;
-
-
-import org.apache.commons.lang3.StringUtils;
-import org.nd4j.common.base.Preconditions;
-import org.nd4j.common.collection.MultiDimensionalMap;
-import org.nd4j.common.primitives.Pair;
-import org.datavec.nlp.tokenization.tokenizer.Tokenizer;
-import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory;
-
-import java.util.ArrayList;
-import java.util.List;
-
-public class ContextLabelRetriever {
-
-
-    private static String BEGIN_LABEL = "<([A-Za-z]+|\\d+)>";
-    private static String END_LABEL = "</([A-Za-z]+|\\d+)>";
-
-
-    private ContextLabelRetriever() {}
-
-
-    /**
-     * Returns a stripped sentence with the indices of words
-     * with certain kinds of labels.
-     *
-     * @param sentence the sentence to process
-     * @return a pair of a post processed sentence
-     * with labels stripped and the spans of
-     * the labels
-     */
-    public static Pair<String, MultiDimensionalMap<Integer, Integer, String>> stringWithLabels(String sentence,
-                    TokenizerFactory tokenizerFactory) {
-        MultiDimensionalMap<Integer, Integer, String> map = MultiDimensionalMap.newHashBackedMap();
-        Tokenizer t = tokenizerFactory.create(sentence);
-        List<String> currTokens = new ArrayList<>();
-        String currLabel = null;
-        String endLabel = null;
-        List<Pair<String, List<String>>> tokensWithSameLabel = new ArrayList<>();
-        while (t.hasMoreTokens()) {
-            String token = t.nextToken();
-            if (token.matches(BEGIN_LABEL)) {
-                currLabel = token;
-
-                //no labels; add these as NONE and begin the new label
-                if (!currTokens.isEmpty()) {
-                    tokensWithSameLabel.add(new Pair<>("NONE", (List<String>) new ArrayList<>(currTokens)));
-                    currTokens.clear();
-
-                }
-
-            } else if (token.matches(END_LABEL)) {
-                if (currLabel == null)
-                    throw new IllegalStateException("Found an ending label with no matching begin label");
-                endLabel = token;
-            } else
-                currTokens.add(token);
-
-            if (currLabel != null && endLabel != null) {
-                currLabel = currLabel.replaceAll("[<>/]", "");
-                endLabel = endLabel.replaceAll("[<>/]", "");
-                Preconditions.checkState(!currLabel.isEmpty(), "Current label is empty!");
-                Preconditions.checkState(!endLabel.isEmpty(), "End label is empty!");
-                Preconditions.checkState(currLabel.equals(endLabel), "Current label begin and end did not match for the parse. Was: %s ending with %s",
-                                currLabel, endLabel);
-
-                tokensWithSameLabel.add(new Pair<>(currLabel, (List<String>) new ArrayList<>(currTokens)));
-                currTokens.clear();
-
-
-                //clear out the tokens
-                currLabel = null;
-                endLabel = null;
-            }
-
-
-        }
-
-        //no labels; add these as NONE and begin the new label
-        if (!currTokens.isEmpty()) {
-            tokensWithSameLabel.add(new Pair<>("none", (List<String>) new ArrayList<>(currTokens)));
-            currTokens.clear();
-
-        }
-
-        //now join the output
-        StringBuilder strippedSentence = new StringBuilder();
-        for (Pair<String, List<String>> tokensWithLabel : tokensWithSameLabel) {
-            String joinedSentence = StringUtils.join(tokensWithLabel.getSecond(), " ");
-            //spaces between separate parts of the sentence
-            if (!(strippedSentence.length() < 1))
-                strippedSentence.append(" ");
-            strippedSentence.append(joinedSentence);
-            int begin = strippedSentence.toString().indexOf(joinedSentence);
-            int end = begin + joinedSentence.length();
-            map.put(begin, end, tokensWithLabel.getFirst());
-        }
-
-
-        return new Pair<>(strippedSentence.toString(), map);
-    }
-
-
-}
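For orientation, a minimal usage sketch of the stringWithLabels API shown above; the example sentence, the PER/LOC labels, and the use of DefaultTokenizerFactory (the whitespace tokenizer that appears later in this diff) are illustrative assumptions, not taken from the removed sources:

    // Strip <LABEL> ... </LABEL> markers and record the character spans of each label.
    Pair<String, MultiDimensionalMap<Integer, Integer, String>> labeled =
            ContextLabelRetriever.stringWithLabels(
                    "<PER> John Smith </PER> went to <LOC> Paris </LOC>",
                    new DefaultTokenizerFactory());
    String stripped = labeled.getFirst();   // "John Smith went to Paris"
    // labeled.getSecond() maps (begin, end) offsets in the stripped sentence to the
    // covering label, e.g. (0, 10) -> "PER", (11, 18) -> "NONE", (19, 24) -> "LOC"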
diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/movingwindow/Util.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/movingwindow/Util.java
deleted file mode 100644
index 8ba0e5d4a..000000000
--- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/movingwindow/Util.java
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.nlp.movingwindow;
-
-
-import org.nd4j.common.primitives.Counter;
-
-import java.util.List;
-import java.util.logging.Level;
-import java.util.logging.Logger;
-
-
-public class Util {
-
-    /**
-     * Returns a thread safe counter
-     *
-     * @return
-     */
-    public static Counter<String> parallelCounter() {
-        return new Counter<>();
-    }
-
-    public static boolean matchesAnyStopWord(List<String> stopWords, String word) {
-        for (String s : stopWords)
-            if (s.equalsIgnoreCase(word))
-                return true;
-        return false;
-    }
-
-    public static Level disableLogging() {
-        Logger logger = Logger.getLogger("org.apache.uima");
-        while (logger.getLevel() == null) {
-            logger = logger.getParent();
-        }
-        Level level = logger.getLevel();
-        logger.setLevel(Level.OFF);
-        return level;
-    }
-
-
-}
diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/movingwindow/Window.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/movingwindow/Window.java
deleted file mode 100644
index 929ae743b..000000000
--- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/movingwindow/Window.java
+++ /dev/null
@@ -1,177 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.nlp.movingwindow;
-
-import org.apache.commons.lang3.StringUtils;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.List;
-
-
-public class Window implements Serializable {
-    /**
-     *
-     */
-    private static final long serialVersionUID = 6359906393699230579L;
-    private List<String> words;
-    private String label = "NONE";
-    private boolean beginLabel;
-    private boolean endLabel;
-    private int median;
-    private static String BEGIN_LABEL = "<([A-Z]+|\\d+)>";
-    private static String END_LABEL = "</([A-Z]+|\\d+)>";
-    private int begin, end;
-
-    /**
-     * Creates a window with a context of size 3
-     * @param words a collection of strings of size 3
-     */
-    public Window(Collection<String> words, int begin, int end) {
-        this(words, 5, begin, end);
-
-    }
-
-    public String asTokens() {
-        return StringUtils.join(words, " ");
-    }
-
-
-    /**
-     * Initialize a window with the given size
-     * @param words the words to use
-     * @param windowSize the size of the window
-     * @param begin the begin index for the window
-     * @param end the end index for the window
-     */
-    public Window(Collection<String> words, int windowSize, int begin, int end) {
-        if (words == null)
-            throw new IllegalArgumentException("Words must be a list of size 3");
-
-        this.words = new ArrayList<>(words);
-        int windowSize1 = windowSize;
-        this.begin = begin;
-        this.end = end;
-        initContext();
-    }
-
-
-    private void initContext() {
-        int median = (int) Math.floor(words.size() / 2);
-        List<String> begin = words.subList(0, median);
-        List<String> after = words.subList(median + 1, words.size());
-
-
-        for (String s : begin) {
-            if (s.matches(BEGIN_LABEL)) {
-                this.label = s.replaceAll("(<|>)", "").replace("/", "");
-                beginLabel = true;
-            } else if (s.matches(END_LABEL)) {
-                endLabel = true;
-                this.label = s.replaceAll("(<|>|/)", "").replace("/", "");
-
-            }
-
-        }
-
-        for (String s1 : after) {
-
-            if (s1.matches(BEGIN_LABEL)) {
-                this.label = s1.replaceAll("(<|>)", "").replace("/", "");
-                beginLabel = true;
-            }
-
-            if (s1.matches(END_LABEL)) {
-                endLabel = true;
-                this.label = s1.replaceAll("(<|>)", "");
-
-            }
-        }
-        this.median = median;
-
-    }
-
-
-
-    @Override
-    public String toString() {
-        return words.toString();
-    }
-
-    public List<String> getWords() {
-        return words;
-    }
-
-    public void setWords(List<String> words) {
-        this.words = words;
-    }
-
-    public String getWord(int i) {
-        return words.get(i);
-    }
-
-    public String getFocusWord() {
-        return words.get(median);
-    }
-
-    public boolean isBeginLabel() {
-        return !label.equals("NONE") && beginLabel;
-    }
-
-    public boolean isEndLabel() {
-        return !label.equals("NONE") && endLabel;
-    }
-
-    public String getLabel() {
-        return label.replace("/", "");
-    }
-
-    public int getWindowSize() {
-        return words.size();
-    }
-
-    public int getMedian() {
-        return median;
-    }
-
-    public void setLabel(String label) {
-        this.label = label;
-    }
-
-    public int getBegin() {
-        return begin;
-    }
-
-    public void setBegin(int begin) {
-        this.begin = begin;
-    }
-
-    public int getEnd() {
-        return end;
-    }
-
-    public void setEnd(int end) {
-        this.end = end;
-    }
-
-
-}
diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/movingwindow/Windows.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/movingwindow/Windows.java
deleted file mode 100644
index 182d45849..000000000
---
a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/movingwindow/Windows.java +++ /dev/null @@ -1,188 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.movingwindow; - - -import org.apache.commons.lang3.StringUtils; -import org.datavec.nlp.tokenization.tokenizer.DefaultStreamTokenizer; -import org.datavec.nlp.tokenization.tokenizer.Tokenizer; -import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory; - -import java.io.InputStream; -import java.util.ArrayList; -import java.util.List; -import java.util.StringTokenizer; - -public class Windows { - - - /** - * Constructs a list of window of size windowSize. - * Note that padding for each window is created as well. - * @param words the words to tokenize and construct windows from - * @param windowSize the window size to generate - * @return the list of windows for the tokenized string - */ - public static List windows(InputStream words, int windowSize) { - Tokenizer tokenizer = new DefaultStreamTokenizer(words); - List list = new ArrayList<>(); - while (tokenizer.hasMoreTokens()) - list.add(tokenizer.nextToken()); - return windows(list, windowSize); - } - - /** - * Constructs a list of window of size windowSize. - * Note that padding for each window is created as well. - * @param words the words to tokenize and construct windows from - * @param tokenizerFactory tokenizer factory to use - * @param windowSize the window size to generate - * @return the list of windows for the tokenized string - */ - public static List windows(InputStream words, TokenizerFactory tokenizerFactory, int windowSize) { - Tokenizer tokenizer = tokenizerFactory.create(words); - List list = new ArrayList<>(); - while (tokenizer.hasMoreTokens()) - list.add(tokenizer.nextToken()); - - if (list.isEmpty()) - throw new IllegalStateException("No tokens found for windows"); - - return windows(list, windowSize); - } - - - /** - * Constructs a list of window of size windowSize. - * Note that padding for each window is created as well. - * @param words the words to tokenize and construct windows from - * @param windowSize the window size to generate - * @return the list of windows for the tokenized string - */ - public static List windows(String words, int windowSize) { - StringTokenizer tokenizer = new StringTokenizer(words); - List list = new ArrayList(); - while (tokenizer.hasMoreTokens()) - list.add(tokenizer.nextToken()); - return windows(list, windowSize); - } - - /** - * Constructs a list of window of size windowSize. - * Note that padding for each window is created as well. 
- * @param words the words to tokenize and construct windows from
- * @param tokenizerFactory tokenizer factory to use
- * @param windowSize the window size to generate
- * @return the list of windows for the tokenized string
- */
-    public static List<Window> windows(String words, TokenizerFactory tokenizerFactory, int windowSize) {
-        Tokenizer tokenizer = tokenizerFactory.create(words);
-        List<String> list = new ArrayList<>();
-        while (tokenizer.hasMoreTokens())
-            list.add(tokenizer.nextToken());
-
-        if (list.isEmpty())
-            throw new IllegalStateException("No tokens found for windows");
-
-        return windows(list, windowSize);
-    }
-
-
-    /**
-     * Constructs a list of window of size windowSize.
-     * Note that padding for each window is created as well.
-     * @param words the words to tokenize and construct windows from
-     * @return the list of windows for the tokenized string
-     */
-    public static List<Window> windows(String words) {
-        StringTokenizer tokenizer = new StringTokenizer(words);
-        List<String> list = new ArrayList<>();
-        while (tokenizer.hasMoreTokens())
-            list.add(tokenizer.nextToken());
-        return windows(list, 5);
-    }
-
-    /**
-     * Constructs a list of window of size windowSize.
-     * Note that padding for each window is created as well.
-     * @param words the words to tokenize and construct windows from
-     * @param tokenizerFactory tokenizer factory to use
-     * @return the list of windows for the tokenized string
-     */
-    public static List<Window> windows(String words, TokenizerFactory tokenizerFactory) {
-        Tokenizer tokenizer = tokenizerFactory.create(words);
-        List<String> list = new ArrayList<>();
-        while (tokenizer.hasMoreTokens())
-            list.add(tokenizer.nextToken());
-        return windows(list, 5);
-    }
-
-
-    /**
-     * Creates a sliding window from text
-     * @param windowSize the window size to use
-     * @param wordPos the position of the word to center
-     * @param sentence the sentence to create a window for
-     * @return a window based on the given sentence
-     */
-    public static Window windowForWordInPosition(int windowSize, int wordPos, List<String> sentence) {
-        List<String> window = new ArrayList<>();
-        List<String> onlyTokens = new ArrayList<>();
-        int contextSize = (int) Math.floor((windowSize - 1) / 2);
-
-        for (int i = wordPos - contextSize; i <= wordPos + contextSize; i++) {
-            if (i < 0)
-                window.add("<s>");
-            else if (i >= sentence.size())
-                window.add("</s>");
-            else {
-                onlyTokens.add(sentence.get(i));
-                window.add(sentence.get(i));
-
-            }
-        }
-
-        String wholeSentence = StringUtils.join(sentence);
-        String window2 = StringUtils.join(onlyTokens);
-        int begin = wholeSentence.indexOf(window2);
-        int end = begin + window2.length();
-        return new Window(window, begin, end);
-
-    }
-
-
-    /**
-     * Constructs a list of window of size windowSize
-     * @param words the words to construct windows from
-     * @return the list of windows for the tokenized string
-     */
-    public static List<Window> windows(List<String> words, int windowSize) {
-
-        List<Window> ret = new ArrayList<>();
-
-        for (int i = 0; i < words.size(); i++)
-            ret.add(windowForWordInPosition(windowSize, i, words));
-
-
-        return ret;
-    }
-
-}
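A brief sketch of how the deleted Windows and Window classes compose (the input sentence is an illustrative assumption): windows(String, int) tokenizes on whitespace and emits one Window per token, padding the sentence edges with <s> and </s>:

    List<Window> windows = Windows.windows("the quick brown fox jumps", 5);
    Window first = windows.get(0);
    // first.getFocusWord() -> "the"
    // first.asTokens()     -> "<s> <s> the quick brown"
    for (Window w : windows)
        System.out.println(w.getFocusWord() + " -> " + w.asTokens());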
diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/reader/TfidfRecordReader.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/reader/TfidfRecordReader.java
deleted file mode 100644
index eaed6ed3a..000000000
--- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/reader/TfidfRecordReader.java
+++ /dev/null
@@ -1,189 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.nlp.reader;
-
-import org.datavec.api.conf.Configuration;
-import org.datavec.api.records.Record;
-import org.datavec.api.records.metadata.RecordMetaData;
-import org.datavec.api.records.metadata.RecordMetaDataURI;
-import org.datavec.api.records.reader.impl.FileRecordReader;
-import org.datavec.api.split.InputSplit;
-import org.datavec.api.vector.Vectorizer;
-import org.datavec.api.writable.NDArrayWritable;
-import org.datavec.api.writable.Writable;
-import org.datavec.nlp.vectorizer.TfidfVectorizer;
-import org.nd4j.linalg.api.ndarray.INDArray;
-
-import java.io.IOException;
-import java.util.*;
-
-public class TfidfRecordReader extends FileRecordReader {
-    private TfidfVectorizer tfidfVectorizer;
-    private List<Record> records = new ArrayList<>();
-    private Iterator<Record> recordIter;
-    private int numFeatures;
-    private boolean initialized = false;
-
-
-    @Override
-    public void initialize(InputSplit split) throws IOException, InterruptedException {
-        initialize(new Configuration(), split);
-    }
-
-    @Override
-    public void initialize(Configuration conf, InputSplit split) throws IOException, InterruptedException {
-        super.initialize(conf, split);
-        //train a new one since it hasn't been specified
-        if (tfidfVectorizer == null) {
-            tfidfVectorizer = new TfidfVectorizer();
-            tfidfVectorizer.initialize(conf);
-
-            //clear out old strings
-            records.clear();
-
-            INDArray ret = tfidfVectorizer.fitTransform(this, new Vectorizer.RecordCallBack() {
-                @Override
-                public void onRecord(Record fullRecord) {
-                    records.add(fullRecord);
-                }
-            });
-
-            //cache the number of features used for each document
-            numFeatures = ret.columns();
-            recordIter = records.iterator();
-        } else {
-            records = new ArrayList<>();
-
-            //the record reader has 2 phases, we are skipping the
-            //document frequency phase and just using the super() to get the file contents
-            //and pass it to the already existing vectorizer.
- while (super.hasNext()) { - Record fileContents = super.nextRecord(); - INDArray transform = tfidfVectorizer.transform(fileContents); - - org.datavec.api.records.impl.Record record = new org.datavec.api.records.impl.Record( - new ArrayList<>(Collections.singletonList(new NDArrayWritable(transform))), - new RecordMetaDataURI(fileContents.getMetaData().getURI(), TfidfRecordReader.class)); - - if (appendLabel) - record.getRecord().add(fileContents.getRecord().get(fileContents.getRecord().size() - 1)); - - records.add(record); - } - - recordIter = records.iterator(); - } - - this.initialized = true; - } - - @Override - public void reset() { - if (inputSplit == null) - throw new UnsupportedOperationException("Cannot reset without first initializing"); - recordIter = records.iterator(); - } - - @Override - public Record nextRecord() { - if (recordIter == null) - return super.nextRecord(); - return recordIter.next(); - } - - @Override - public List next() { - return nextRecord().getRecord(); - } - - @Override - public boolean hasNext() { - //we aren't done vectorizing yet - if (recordIter == null) - return super.hasNext(); - return recordIter.hasNext(); - } - - @Override - public void close() throws IOException { - - } - - @Override - public void setConf(Configuration conf) { - this.conf = conf; - } - - @Override - public Configuration getConf() { - return conf; - } - - public TfidfVectorizer getTfidfVectorizer() { - return tfidfVectorizer; - } - - public void setTfidfVectorizer(TfidfVectorizer tfidfVectorizer) { - if (initialized) { - throw new IllegalArgumentException( - "Setting TfidfVectorizer after TfidfRecordReader initialization doesn't have an effect"); - } - this.tfidfVectorizer = tfidfVectorizer; - } - - public int getNumFeatures() { - return numFeatures; - } - - public void shuffle() { - this.shuffle(new Random()); - } - - public void shuffle(Random random) { - Collections.shuffle(this.records, random); - this.reset(); - } - - @Override - public Record loadFromMetaData(RecordMetaData recordMetaData) throws IOException { - return loadFromMetaData(Collections.singletonList(recordMetaData)).get(0); - } - - @Override - public List loadFromMetaData(List recordMetaDatas) throws IOException { - List out = new ArrayList<>(); - - for (Record fileContents : super.loadFromMetaData(recordMetaDatas)) { - INDArray transform = tfidfVectorizer.transform(fileContents); - - org.datavec.api.records.impl.Record record = new org.datavec.api.records.impl.Record( - new ArrayList<>(Collections.singletonList(new NDArrayWritable(transform))), - new RecordMetaDataURI(fileContents.getMetaData().getURI(), TfidfRecordReader.class)); - - if (appendLabel) - record.getRecord().add(fileContents.getRecord().get(fileContents.getRecord().size() - 1)); - out.add(record); - } - - return out; - } -} - diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/stopwords/StopWords.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/stopwords/StopWords.java deleted file mode 100644 index 189ad6bc9..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/stopwords/StopWords.java +++ /dev/null @@ -1,44 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. 
- * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.stopwords; - -import org.apache.commons.io.IOUtils; - -import java.io.IOException; -import java.util.List; - -public class StopWords { - - private static List stopWords; - - @SuppressWarnings("unchecked") - public static List getStopWords() { - - try { - if (stopWords == null) - stopWords = IOUtils.readLines(StopWords.class.getResourceAsStream("/stopwords")); - } catch (IOException e) { - throw new RuntimeException(e); - } - return stopWords; - } - -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/ConcurrentTokenizer.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/ConcurrentTokenizer.java deleted file mode 100644 index d604aae73..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/ConcurrentTokenizer.java +++ /dev/null @@ -1,125 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.tokenization.tokenizer; - -import opennlp.tools.tokenize.TokenizerME; -import opennlp.tools.tokenize.TokenizerModel; -import opennlp.tools.util.Span; -import opennlp.uima.tokenize.AbstractTokenizer; -import opennlp.uima.tokenize.TokenizerModelResource; -import opennlp.uima.util.AnnotatorUtil; -import opennlp.uima.util.UimaUtil; -import org.apache.uima.UimaContext; -import org.apache.uima.analysis_engine.AnalysisEngineProcessException; -import org.apache.uima.cas.CAS; -import org.apache.uima.cas.Feature; -import org.apache.uima.cas.TypeSystem; -import org.apache.uima.cas.text.AnnotationFS; -import org.apache.uima.resource.ResourceAccessException; -import org.apache.uima.resource.ResourceInitializationException; - -public class ConcurrentTokenizer extends AbstractTokenizer { - - /** - * The OpenNLP tokenizer. - */ - private TokenizerME tokenizer; - - private Feature probabilityFeature; - - @Override - public synchronized void process(CAS cas) throws AnalysisEngineProcessException { - super.process(cas); - } - - /** - * Initializes a new instance. 
- * - * Note: Use {@link #initialize(UimaContext) } to initialize - * this instance. Not use the constructor. - */ - public ConcurrentTokenizer() { - super("OpenNLP Tokenizer"); - - // must not be implemented ! - } - - /** - * Initializes the current instance with the given context. - * - * Note: Do all initialization in this method, do not use the constructor. - */ - public void initialize(UimaContext context) throws ResourceInitializationException { - - super.initialize(context); - - TokenizerModel model; - - try { - TokenizerModelResource modelResource = - (TokenizerModelResource) context.getResourceObject(UimaUtil.MODEL_PARAMETER); - - model = modelResource.getModel(); - } catch (ResourceAccessException e) { - throw new ResourceInitializationException(e); - } - - tokenizer = new TokenizerME(model); - } - - /** - * Initializes the type system. - */ - public void typeSystemInit(TypeSystem typeSystem) throws AnalysisEngineProcessException { - - super.typeSystemInit(typeSystem); - - probabilityFeature = AnnotatorUtil.getOptionalFeatureParameter(context, tokenType, - UimaUtil.PROBABILITY_FEATURE_PARAMETER, CAS.TYPE_NAME_DOUBLE); - } - - - @Override - protected Span[] tokenize(CAS cas, AnnotationFS sentence) { - return tokenizer.tokenizePos(sentence.getCoveredText()); - } - - @Override - protected void postProcessAnnotations(Span[] tokens, AnnotationFS[] tokenAnnotations) { - // if interest - if (probabilityFeature != null) { - double tokenProbabilties[] = tokenizer.getTokenProbabilities(); - - for (int i = 0; i < tokenAnnotations.length; i++) { - tokenAnnotations[i].setDoubleValue(probabilityFeature, tokenProbabilties[i]); - } - } - } - - /** - * Releases allocated resources. - */ - public void destroy() { - // dereference model to allow garbage collection - tokenizer = null; - } -} - diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/DefaultStreamTokenizer.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/DefaultStreamTokenizer.java deleted file mode 100644 index 9f10fb878..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/DefaultStreamTokenizer.java +++ /dev/null @@ -1,107 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.tokenization.tokenizer; - - -import java.io.*; -import java.util.ArrayList; -import java.util.List; - -/** - * Tokenizer based on the {@link java.io.StreamTokenizer} - * @author Adam Gibson - * - */ -public class DefaultStreamTokenizer implements Tokenizer { - - private StreamTokenizer streamTokenizer; - private TokenPreProcess tokenPreProcess; - - - public DefaultStreamTokenizer(InputStream is) { - Reader r = new BufferedReader(new InputStreamReader(is)); - streamTokenizer = new StreamTokenizer(r); - - } - - @Override - public boolean hasMoreTokens() { - if (streamTokenizer.ttype != StreamTokenizer.TT_EOF) { - try { - streamTokenizer.nextToken(); - } catch (IOException e1) { - throw new RuntimeException(e1); - } - } - return streamTokenizer.ttype != StreamTokenizer.TT_EOF && streamTokenizer.ttype != -1; - } - - @Override - public int countTokens() { - return getTokens().size(); - } - - @Override - public String nextToken() { - StringBuilder sb = new StringBuilder(); - - - if (streamTokenizer.ttype == StreamTokenizer.TT_WORD) { - sb.append(streamTokenizer.sval); - } else if (streamTokenizer.ttype == StreamTokenizer.TT_NUMBER) { - sb.append(streamTokenizer.nval); - } else if (streamTokenizer.ttype == StreamTokenizer.TT_EOL) { - try { - while (streamTokenizer.ttype == StreamTokenizer.TT_EOL) - streamTokenizer.nextToken(); - } catch (IOException e) { - throw new RuntimeException(e); - - } - } - - else if (hasMoreTokens()) - return nextToken(); - - - String ret = sb.toString(); - - if (tokenPreProcess != null) - ret = tokenPreProcess.preProcess(ret); - return ret; - - } - - @Override - public List getTokens() { - List tokens = new ArrayList<>(); - while (hasMoreTokens()) { - tokens.add(nextToken()); - } - return tokens; - } - - @Override - public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) { - this.tokenPreProcess = tokenPreProcessor; - } - -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/DefaultTokenizer.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/DefaultTokenizer.java deleted file mode 100644 index f9ba4a0aa..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/DefaultTokenizer.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.tokenization.tokenizer; - -import java.util.ArrayList; -import java.util.List; -import java.util.StringTokenizer; - -/** - * Default tokenizer - * @author Adam Gibson - */ -public class DefaultTokenizer implements Tokenizer { - - public DefaultTokenizer(String tokens) { - tokenizer = new StringTokenizer(tokens); - } - - private StringTokenizer tokenizer; - private TokenPreProcess tokenPreProcess; - - @Override - public boolean hasMoreTokens() { - return tokenizer.hasMoreTokens(); - } - - @Override - public int countTokens() { - return tokenizer.countTokens(); - } - - @Override - public String nextToken() { - String base = tokenizer.nextToken(); - if (tokenPreProcess != null) - base = tokenPreProcess.preProcess(base); - return base; - } - - @Override - public List getTokens() { - List tokens = new ArrayList<>(); - while (hasMoreTokens()) { - tokens.add(nextToken()); - } - return tokens; - } - - @Override - public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) { - this.tokenPreProcess = tokenPreProcessor; - - } - - - -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/PosUimaTokenizer.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/PosUimaTokenizer.java deleted file mode 100644 index 2478b7ae3..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/PosUimaTokenizer.java +++ /dev/null @@ -1,136 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.nlp.tokenization.tokenizer;
-
-import org.apache.uima.analysis_engine.AnalysisEngine;
-import org.apache.uima.cas.CAS;
-import org.apache.uima.fit.factory.AnalysisEngineFactory;
-import org.apache.uima.fit.util.JCasUtil;
-import org.cleartk.token.type.Sentence;
-import org.cleartk.token.type.Token;
-import org.datavec.nlp.annotator.PoStagger;
-import org.datavec.nlp.annotator.SentenceAnnotator;
-import org.datavec.nlp.annotator.StemmerAnnotator;
-import org.datavec.nlp.annotator.TokenizerAnnotator;
-
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.List;
-
-public class PosUimaTokenizer implements Tokenizer {
-
-    private static AnalysisEngine engine;
-    private List<String> tokens;
-    private Collection<String> allowedPosTags;
-    private int index;
-    private static CAS cas;
-
-    public PosUimaTokenizer(String tokens, AnalysisEngine engine, Collection<String> allowedPosTags) {
-        if (PosUimaTokenizer.engine == null)
-            PosUimaTokenizer.engine = engine;
-        this.allowedPosTags = allowedPosTags;
-        this.tokens = new ArrayList<>();
-        try {
-            if (cas == null)
-                cas = engine.newCAS();
-
-            cas.reset();
-            cas.setDocumentText(tokens);
-            PosUimaTokenizer.engine.process(cas);
-            for (Sentence s : JCasUtil.select(cas.getJCas(), Sentence.class)) {
-                for (Token t : JCasUtil.selectCovered(Token.class, s)) {
-                    //add NONE for each invalid token
-                    if (valid(t))
-                        if (t.getLemma() != null)
-                            this.tokens.add(t.getLemma());
-                        else if (t.getStem() != null)
-                            this.tokens.add(t.getStem());
-                        else
-                            this.tokens.add(t.getCoveredText());
-                    else
-                        this.tokens.add("NONE");
-                }
-            }
-
-
-
-        } catch (Exception e) {
-            throw new RuntimeException(e);
-        }
-
-    }
-
-    private boolean valid(Token token) {
-        String check = token.getCoveredText();
-        if (check.matches("<[A-Z]+>") || check.matches("</[A-Z]+>"))
-            return false;
-        else if (token.getPos() != null && !this.allowedPosTags.contains(token.getPos()))
-            return false;
-        return true;
-    }
-
-
-
-    @Override
-    public boolean hasMoreTokens() {
-        return index < tokens.size();
-    }
-
-    @Override
-    public int countTokens() {
-        return tokens.size();
-    }
-
-    @Override
-    public String nextToken() {
-        String ret = tokens.get(index);
-        index++;
-        return ret;
-    }
-
-    @Override
-    public List<String> getTokens() {
-        List<String> tokens = new ArrayList<>();
-        while (hasMoreTokens()) {
-            tokens.add(nextToken());
-        }
-        return tokens;
-    }
-
-    public static AnalysisEngine defaultAnalysisEngine() {
-        try {
-            return AnalysisEngineFactory.createEngine(AnalysisEngineFactory.createEngineDescription(
-                            SentenceAnnotator.getDescription(), TokenizerAnnotator.getDescription(),
-                            PoStagger.getDescription("en"), StemmerAnnotator.getDescription("English")));
-        } catch (Exception e) {
-            throw new RuntimeException(e);
-        }
-    }
-
-    @Override
-    public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
-        // TODO Auto-generated method stub
-
-    }
-
-
-
-}
diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/TokenPreProcess.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/TokenPreProcess.java
deleted file mode 100644
index 55412ac77..000000000
--- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/TokenPreProcess.java
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.nlp.tokenization.tokenizer;
-
-
-public interface TokenPreProcess {
-
-    /**
-     * Pre process a token
-     * @param token the token to pre process
-     * @return the preprocessed token
-     */
-    String preProcess(String token);
-
-
-}
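A minimal sketch of implementing this interface; the class name and the stripping rule are illustrative assumptions, not part of the removed sources:

    // Hypothetical pre-processor: lower-case each token and drop any run of
    // trailing punctuation characters before the token is emitted downstream.
    public class TrailingPunctuationPreProcessor implements TokenPreProcess {
        @Override
        public String preProcess(String token) {
            return token.toLowerCase().replaceAll("[.,;:!?]+$", "");
        }
    }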
diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/Tokenizer.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/Tokenizer.java
deleted file mode 100644
index d8f8d2c9a..000000000
--- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/Tokenizer.java
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.nlp.tokenization.tokenizer;
-
-import java.util.List;
-
-public interface Tokenizer {
-
-    /**
-     * Whether any more tokens
-     * are left to iterate over
-     * @return true if more tokens remain
-     */
-    boolean hasMoreTokens();
-
-    /**
-     * The number of tokens in the tokenizer
-     * @return the number of tokens
-     */
-    int countTokens();
-
-    /**
-     * The next token (word usually) in the string
-     * @return the next token in the string if any
-     */
-    String nextToken();
-
-    /**
-     * Returns a list of all the tokens
-     * @return a list of all the tokens
-     */
-    List<String> getTokens();
-
-    /**
-     * Set the token pre process
-     * @param tokenPreProcessor the token pre processor to set
-     */
-    void setTokenPreProcessor(TokenPreProcess tokenPreProcessor);
-
-
-
-}
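The interface above is consumed as a simple pull-style iterator. A sketch using DefaultTokenizerFactory and LowerCasePreProcessor, both of which appear later in this diff (the input string is an illustrative assumption):

    Tokenizer tokenizer = new DefaultTokenizerFactory().create("The Quick Brown Fox");
    tokenizer.setTokenPreProcessor(new LowerCasePreProcessor());
    while (tokenizer.hasMoreTokens())
        System.out.println(tokenizer.nextToken()); // the, quick, brown, fox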
diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/UimaTokenizer.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/UimaTokenizer.java
deleted file mode 100644
index 7e430029b..000000000
--- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/UimaTokenizer.java
+++ /dev/null
@@ -1,123 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.nlp.tokenization.tokenizer;
-
-import org.apache.uima.cas.CAS;
-import org.apache.uima.fit.util.JCasUtil;
-import org.cleartk.token.type.Token;
-import org.datavec.nlp.uima.UimaResource;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.List;
-
-/**
- * Tokenizer based on the passed in analysis engine
- * @author Adam Gibson
- *
- */
-public class UimaTokenizer implements Tokenizer {
-
-    private List<String> tokens;
-    private int index;
-    private static Logger log = LoggerFactory.getLogger(UimaTokenizer.class);
-    private boolean checkForLabel;
-    private TokenPreProcess tokenPreProcessor;
-
-
-    public UimaTokenizer(String tokens, UimaResource resource, boolean checkForLabel) {
-
-        this.checkForLabel = checkForLabel;
-        this.tokens = new ArrayList<>();
-        try {
-            CAS cas = resource.process(tokens);
-
-            Collection<Token> tokenList = JCasUtil.select(cas.getJCas(), Token.class);
-
-            for (Token t : tokenList) {
-
-                if (!checkForLabel || valid(t.getCoveredText()))
-                    if (t.getLemma() != null)
-                        this.tokens.add(t.getLemma());
-                    else if (t.getStem() != null)
-                        this.tokens.add(t.getStem());
-                    else
-                        this.tokens.add(t.getCoveredText());
-            }
-
-
-            resource.release(cas);
-
-
-        } catch (Exception e) {
-            log.error("", e);
-            throw new RuntimeException(e);
-        }
-
-    }
-
-    private boolean valid(String check) {
-        if (check.matches("<[A-Z]+>") || check.matches("</[A-Z]+>"))
-            return false;
-        return true;
-    }
-
-
-
-    @Override
-    public boolean hasMoreTokens() {
-        return index < tokens.size();
-    }
-
-    @Override
-    public int countTokens() {
-        return tokens.size();
-    }
-
-    @Override
-    public String nextToken() {
-        String ret = tokens.get(index);
-        index++;
-        if (tokenPreProcessor != null) {
-            ret = tokenPreProcessor.preProcess(ret);
-        }
-        return ret;
-    }
-
-    @Override
-    public List<String> getTokens() {
-        List<String> tokens = new ArrayList<>();
-        while (hasMoreTokens()) {
-            tokens.add(nextToken());
-        }
-        return tokens;
-    }
-
-    @Override
-    public void setTokenPreProcessor(TokenPreProcess tokenPreProcessor) {
-        this.tokenPreProcessor = tokenPreProcessor;
-    }
-
-
-
-}
diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/preprocessor/EndingPreProcessor.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/preprocessor/EndingPreProcessor.java
deleted file mode 100644
index 52b572358..000000000
--- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/preprocessor/EndingPreProcessor.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.tokenization.tokenizer.preprocessor; - - -import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; - -/** - * Gets rid of endings: - * - * ed,ing, ly, s, . - * @author Adam Gibson - */ -public class EndingPreProcessor implements TokenPreProcess { - @Override - public String preProcess(String token) { - if (token.endsWith("s") && !token.endsWith("ss")) - token = token.substring(0, token.length() - 1); - if (token.endsWith(".")) - token = token.substring(0, token.length() - 1); - if (token.endsWith("ed")) - token = token.substring(0, token.length() - 2); - if (token.endsWith("ing")) - token = token.substring(0, token.length() - 3); - if (token.endsWith("ly")) - token = token.substring(0, token.length() - 2); - return token; - } -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/preprocessor/LowerCasePreProcessor.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/preprocessor/LowerCasePreProcessor.java deleted file mode 100644 index adb3f322b..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizer/preprocessor/LowerCasePreProcessor.java +++ /dev/null @@ -1,30 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.tokenization.tokenizer.preprocessor; - -import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; - -public class LowerCasePreProcessor implements TokenPreProcess { - @Override - public String preProcess(String token) { - return token.toLowerCase(); - } -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizerfactory/DefaultTokenizerFactory.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizerfactory/DefaultTokenizerFactory.java deleted file mode 100644 index 45b571afe..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizerfactory/DefaultTokenizerFactory.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. 
- * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.tokenization.tokenizerfactory; - - - -import org.datavec.nlp.tokenization.tokenizer.DefaultStreamTokenizer; -import org.datavec.nlp.tokenization.tokenizer.DefaultTokenizer; -import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; -import org.datavec.nlp.tokenization.tokenizer.Tokenizer; - -import java.io.InputStream; - -/** - * Default tokenizer based on string tokenizer or stream tokenizer - * @author Adam Gibson - */ -public class DefaultTokenizerFactory implements TokenizerFactory { - - private TokenPreProcess tokenPreProcess; - - @Override - public Tokenizer create(String toTokenize) { - DefaultTokenizer t = new DefaultTokenizer(toTokenize); - t.setTokenPreProcessor(tokenPreProcess); - return t; - } - - @Override - public Tokenizer create(InputStream toTokenize) { - Tokenizer t = new DefaultStreamTokenizer(toTokenize); - t.setTokenPreProcessor(tokenPreProcess); - return t; - } - - @Override - public void setTokenPreProcessor(TokenPreProcess preProcessor) { - this.tokenPreProcess = preProcessor; - } - - -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizerfactory/PosUimaTokenizerFactory.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizerfactory/PosUimaTokenizerFactory.java deleted file mode 100644 index 8ef9dce90..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizerfactory/PosUimaTokenizerFactory.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.tokenization.tokenizerfactory; - - -import org.apache.uima.analysis_engine.AnalysisEngine; -import org.datavec.nlp.annotator.PoStagger; -import org.datavec.nlp.annotator.SentenceAnnotator; -import org.datavec.nlp.annotator.StemmerAnnotator; -import org.datavec.nlp.annotator.TokenizerAnnotator; -import org.datavec.nlp.tokenization.tokenizer.PosUimaTokenizer; -import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; -import org.datavec.nlp.tokenization.tokenizer.Tokenizer; - -import java.io.InputStream; -import java.util.Collection; - -import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngine; -import static org.apache.uima.fit.factory.AnalysisEngineFactory.createEngineDescription; - -public class PosUimaTokenizerFactory implements TokenizerFactory { - - private AnalysisEngine tokenizer; - private Collection allowedPoSTags; - private TokenPreProcess tokenPreProcess; - - - public PosUimaTokenizerFactory(Collection allowedPoSTags) { - this(defaultAnalysisEngine(), allowedPoSTags); - } - - public PosUimaTokenizerFactory(AnalysisEngine tokenizer, Collection allowedPosTags) { - this.tokenizer = tokenizer; - this.allowedPoSTags = allowedPosTags; - } - - - public static AnalysisEngine defaultAnalysisEngine() { - try { - return createEngine(createEngineDescription(SentenceAnnotator.getDescription(), - TokenizerAnnotator.getDescription(), PoStagger.getDescription("en"), - StemmerAnnotator.getDescription("English"))); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - - @Override - public Tokenizer create(String toTokenize) { - PosUimaTokenizer t = new PosUimaTokenizer(toTokenize, tokenizer, allowedPoSTags); - t.setTokenPreProcessor(tokenPreProcess); - return t; - } - - @Override - public Tokenizer create(InputStream toTokenize) { - throw new UnsupportedOperationException(); - } - - @Override - public void setTokenPreProcessor(TokenPreProcess preProcessor) { - this.tokenPreProcess = preProcessor; - } - - -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizerfactory/TokenizerFactory.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizerfactory/TokenizerFactory.java deleted file mode 100644 index ccbb93d98..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizerfactory/TokenizerFactory.java +++ /dev/null @@ -1,62 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
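A hypothetical sketch of the PosUimaTokenizerFactory above. The allowed-tag collection (Penn treebank noun tags here, purely illustrative) is handed to each PosUimaTokenizer, which restricts output by part-of-speech; the exact filtering behavior lives in PosUimaTokenizer, which is not shown in this diff:

Collection<String> allowedTags = Arrays.asList("NN", "NNS", "NNP");
TokenizerFactory posFactory = new PosUimaTokenizerFactory(allowedTags); // uses defaultAnalysisEngine()
Tokenizer t = posFactory.create("The quick brown fox jumps");
List<String> nounsOnly = t.getTokens(); // filtering delegated to PosUimaTokenizer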
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.tokenization.tokenizerfactory; - - - -import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; -import org.datavec.nlp.tokenization.tokenizer.Tokenizer; -import org.nd4j.shade.jackson.annotation.JsonTypeInfo; - -import java.io.InputStream; - -/** - * Generates a tokenizer for a given string - * @author Adam Gibson - * - */ -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") -public interface TokenizerFactory { - - /** - * The tokenizer to create - * @param toTokenize the string to create the tokenizer with - * @return the new tokenizer - */ - Tokenizer create(String toTokenize); - - /** - * Create a tokenizer based on an input stream - * @param toTokenize the input stream to tokenize - * @return the new tokenizer - */ - Tokenizer create(InputStream toTokenize); - - /** - * Sets a token pre processor to be used - * with every tokenizer - * @param preProcessor the token pre processor to use - */ - void setTokenPreProcessor(TokenPreProcess preProcessor); - - - -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizerfactory/UimaTokenizerFactory.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizerfactory/UimaTokenizerFactory.java deleted file mode 100644 index d92a42d9a..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/tokenization/tokenizerfactory/UimaTokenizerFactory.java +++ /dev/null @@ -1,138 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.tokenization.tokenizerfactory; - -import org.apache.uima.analysis_engine.AnalysisEngine; -import org.apache.uima.fit.factory.AnalysisEngineFactory; -import org.apache.uima.resource.ResourceInitializationException; -import org.datavec.nlp.annotator.SentenceAnnotator; -import org.datavec.nlp.annotator.TokenizerAnnotator; -import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; -import org.datavec.nlp.tokenization.tokenizer.Tokenizer; -import org.datavec.nlp.tokenization.tokenizer.UimaTokenizer; -import org.datavec.nlp.uima.UimaResource; - -import java.io.InputStream; - - -/** - * Uses a uima {@link AnalysisEngine} to - * tokenize text.
- * - * - * @author Adam Gibson - * - */ -public class UimaTokenizerFactory implements TokenizerFactory { - - - private UimaResource uimaResource; - private boolean checkForLabel; - private static AnalysisEngine defaultAnalysisEngine; - private TokenPreProcess preProcess; - - public UimaTokenizerFactory() throws ResourceInitializationException { - this(defaultAnalysisEngine(), true); - } - - - public UimaTokenizerFactory(UimaResource resource) { - this(resource, true); - } - - - public UimaTokenizerFactory(AnalysisEngine tokenizer) { - this(tokenizer, true); - } - - - - public UimaTokenizerFactory(UimaResource resource, boolean checkForLabel) { - this.uimaResource = resource; - this.checkForLabel = checkForLabel; - } - - public UimaTokenizerFactory(boolean checkForLabel) throws ResourceInitializationException { - this(defaultAnalysisEngine(), checkForLabel); - } - - - - public UimaTokenizerFactory(AnalysisEngine tokenizer, boolean checkForLabel) { - super(); - this.checkForLabel = checkForLabel; - try { - this.uimaResource = new UimaResource(tokenizer); - - - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - - - @Override - public Tokenizer create(String toTokenize) { - if (toTokenize == null || toTokenize.isEmpty()) - throw new IllegalArgumentException("Unable to proceed; no sentence to tokenize"); - Tokenizer ret = new UimaTokenizer(toTokenize, uimaResource, checkForLabel); - ret.setTokenPreProcessor(preProcess); - return ret; - } - - - public UimaResource getUimaResource() { - return uimaResource; - } - - - /** - * Creates a tokenization/stemming pipeline - * @return a tokenization/stemming pipeline - */ - public static AnalysisEngine defaultAnalysisEngine() { - try { - if (defaultAnalysisEngine == null) - - defaultAnalysisEngine = AnalysisEngineFactory.createEngine( - AnalysisEngineFactory.createEngineDescription(SentenceAnnotator.getDescription(), - TokenizerAnnotator.getDescription())); - - return defaultAnalysisEngine; - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - - @Override - public Tokenizer create(InputStream toTokenize) { - throw new UnsupportedOperationException(); - } - - @Override - public void setTokenPreProcessor(TokenPreProcess preProcessor) { - this.preProcess = preProcessor; - } - - -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/BagOfWordsTransform.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/BagOfWordsTransform.java deleted file mode 100644 index 058d5b4f3..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/BagOfWordsTransform.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
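A hypothetical usage sketch of the factory contract, contrasting the string-backed DefaultTokenizerFactory with the UIMA-backed factory above (imports as in the surrounding files; the default UIMA pipeline may fail to initialize, hence the checked exception):

public static void demo() throws ResourceInitializationException {
    // Plain string tokenization, preprocessor applied to every token:
    TokenizerFactory plain = new DefaultTokenizerFactory();
    plain.setTokenPreProcessor(new LowerCasePreProcessor());
    Tokenizer t = plain.create("Hello World");
    while (t.hasMoreTokens())
        System.out.println(t.nextToken()); // "hello", "world"

    // UIMA-backed tokenization via the sentence + tokenizer pipeline:
    TokenizerFactory uima = new UimaTokenizerFactory(); // builds defaultAnalysisEngine()
    System.out.println(uima.create("Hello world.").getTokens());
}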
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.transforms; - -import org.datavec.api.transform.Transform; -import org.datavec.api.writable.Writable; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.shade.jackson.annotation.JsonTypeInfo; - -import java.util.List; - -@JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "@class") -public interface BagOfWordsTransform extends Transform { - - - /** - * The output shape of the transform (usually 1 x number of words) - * @return the output shape - */ - long[] outputShape(); - - /** - * The vocab words in the transform. - * These are the words that were accumulated - * when building a vocabulary. - * (This is generally built via some form of - * minimum word frequency scanning; the accumulated - * vocabulary is then mapped on to a list of vocab words) - * @return the vocab words for the transform - */ - List<String> vocabWords(); - - /** - * Transform for a list of tokens - * that are objects. This is to allow loose - * typing for tokens that are unique (non string) - * @param tokens the token objects to transform - * @return the output {@link INDArray} (a tokens.size() by {@link #vocabWords()}.size() array) - */ - INDArray transformFromObject(List<List<Object>> tokens); - - - /** - * Transform for a list of tokens - * that are {@link Writable} (generally {@link org.datavec.api.writable.Text}) - * @param tokens the token objects to transform - * @return the output {@link INDArray} (a tokens.size() by {@link #vocabWords()}.size() array) - */ - INDArray transformFrom(List<List<Writable>> tokens); - -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/BaseWordMapTransform.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/BaseWordMapTransform.java deleted file mode 100644 index dbba4bb45..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/BaseWordMapTransform.java +++ /dev/null @@ -1,24 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.transforms; - -public class BaseWordMapTransform { -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/GazeteerTransform.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/GazeteerTransform.java deleted file mode 100644 index 2784ae877..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/GazeteerTransform.java +++ /dev/null @@ -1,146 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.transforms; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import org.datavec.api.transform.metadata.ColumnMetaData; -import org.datavec.api.transform.metadata.NDArrayMetaData; -import org.datavec.api.transform.transform.BaseColumnTransform; -import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.shade.jackson.annotation.JsonCreator; -import org.nd4j.shade.jackson.annotation.JsonIgnoreProperties; -import org.nd4j.shade.jackson.annotation.JsonInclude; -import org.nd4j.shade.jackson.annotation.JsonProperty; - -import java.util.Collections; -import java.util.HashSet; -import java.util.List; -import java.util.Set; - -@Data -@EqualsAndHashCode(callSuper = true) -@JsonInclude(JsonInclude.Include.NON_NULL) -@JsonIgnoreProperties({"gazeteer"}) -public class GazeteerTransform extends BaseColumnTransform implements BagOfWordsTransform { - - private String newColumnName; - private List wordList; - private Set gazeteer; - - @JsonCreator - public GazeteerTransform(@JsonProperty("columnName") String columnName, - @JsonProperty("newColumnName")String newColumnName, - @JsonProperty("wordList") List wordList) { - super(columnName); - this.newColumnName = newColumnName; - this.wordList = wordList; - this.gazeteer = new HashSet<>(wordList); - } - - @Override - public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) { - return new NDArrayMetaData(newName,new long[]{wordList.size()}); - } - - @Override - public Writable map(Writable columnWritable) { - throw new UnsupportedOperationException(); - } - - @Override - public Object mapSequence(Object sequence) { - List> sequenceInput = (List>) sequence; - INDArray ret = Nd4j.create(DataType.FLOAT, wordList.size()); - - for(List list : sequenceInput) { - for(Object token : list) { - String s = token.toString(); - 
if(gazeteer.contains(s)) { - ret.putScalar(wordList.indexOf(s),1); - } - } - } - return ret; - } - - - - @Override - public List> mapSequence(List> sequence) { - INDArray arr = (INDArray) mapSequence((Object) sequence); - return Collections.singletonList(Collections.singletonList(new NDArrayWritable(arr))); - } - - @Override - public String toString() { - return newColumnName; - } - - @Override - public Object map(Object input) { - return gazeteer.contains(input.toString()); - } - - @Override - public String outputColumnName() { - return newColumnName; - } - - @Override - public String[] outputColumnNames() { - return new String[]{newColumnName}; - } - - @Override - public String[] columnNames() { - return new String[]{columnName()}; - } - - @Override - public String columnName() { - return columnName; - } - - @Override - public long[] outputShape() { - return new long[]{wordList.size()}; - } - - @Override - public List vocabWords() { - return wordList; - } - - @Override - public INDArray transformFromObject(List> tokens) { - return (INDArray) mapSequence(tokens); - } - - @Override - public INDArray transformFrom(List> tokens) { - return (INDArray) mapSequence((Object) tokens); - } -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/MultiNlpTransform.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/MultiNlpTransform.java deleted file mode 100644 index e69f32587..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/MultiNlpTransform.java +++ /dev/null @@ -1,150 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
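The heart of the deleted GazeteerTransform above is a multi-hot membership vector over wordList. A minimal standalone sketch of the same logic (token values hypothetical; imports as in the surrounding file):

List<String> wordList = Arrays.asList("apple", "banana", "cherry", "date", "eggplant");
Set<String> gazeteer = new HashSet<>(wordList);
INDArray out = Nd4j.create(DataType.FLOAT, wordList.size());
for (String token : Arrays.asList("i", "like", "cherry", "date")) {
    if (gazeteer.contains(token))
        out.putScalar(wordList.indexOf(token), 1.0); // presence, not counts
}
// out == [0, 0, 1, 1, 0]; tokens outside the word list contribute nothing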
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - - -package org.datavec.nlp.transforms; - -import org.datavec.api.transform.metadata.ColumnMetaData; -import org.datavec.api.transform.metadata.NDArrayMetaData; -import org.datavec.api.transform.transform.BaseColumnTransform; -import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.list.NDArrayList; -import org.nd4j.shade.jackson.annotation.JsonCreator; -import org.nd4j.shade.jackson.annotation.JsonProperty; - -import java.util.Collections; -import java.util.List; - -public class MultiNlpTransform extends BaseColumnTransform implements BagOfWordsTransform { - - private BagOfWordsTransform[] transforms; - private String newColumnName; - private List vocabWords; - - /** - * - * @param columnName - * @param transforms - * @param newColumnName - */ - @JsonCreator - public MultiNlpTransform(@JsonProperty("columnName") String columnName, - @JsonProperty("transforms") BagOfWordsTransform[] transforms, - @JsonProperty("newColumnName") String newColumnName) { - super(columnName); - this.transforms = transforms; - this.vocabWords = transforms[0].vocabWords(); - if(transforms.length > 1) { - for(int i = 1; i < transforms.length; i++) { - if(!transforms[i].vocabWords().equals(vocabWords)) { - throw new IllegalArgumentException("Vocab words not consistent across transforms!"); - } - } - } - - this.newColumnName = newColumnName; - } - - @Override - public Object mapSequence(Object sequence) { - NDArrayList ndArrayList = new NDArrayList(); - for(BagOfWordsTransform bagofWordsTransform : transforms) { - ndArrayList.addAll(new NDArrayList(bagofWordsTransform.transformFromObject((List>) sequence))); - } - - return ndArrayList.array(); - } - - @Override - public List> mapSequence(List> sequence) { - return Collections.singletonList(Collections.singletonList(new NDArrayWritable(transformFrom(sequence)))); - } - - @Override - public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) { - return new NDArrayMetaData(newName,outputShape()); - } - - @Override - public Writable map(Writable columnWritable) { - throw new UnsupportedOperationException("Only able to add for time series"); - } - - @Override - public String toString() { - return newColumnName; - } - - @Override - public Object map(Object input) { - throw new UnsupportedOperationException("Only able to add for time series"); - } - - @Override - public long[] outputShape() { - long[] ret = new long[transforms[0].outputShape().length]; - int validatedRank = transforms[0].outputShape().length; - for(int i = 1; i < transforms.length; i++) { - if(transforms[i].outputShape().length != validatedRank) { - throw new IllegalArgumentException("Inconsistent shape length at transform " + i + " , should have been: " + validatedRank); - } - } - for(int i = 0; i < transforms.length; i++) { - for(int j = 0; j < validatedRank; j++) - ret[j] += transforms[i].outputShape()[j]; - } - - return ret; - } - - @Override - public List vocabWords() { - return vocabWords; - } - - @Override - public INDArray transformFromObject(List> tokens) { - NDArrayList ndArrayList = new NDArrayList(); - for(BagOfWordsTransform bagofWordsTransform : transforms) { - INDArray arr2 = bagofWordsTransform.transformFromObject(tokens); - arr2 = arr2.reshape(arr2.length()); - NDArrayList newList = new NDArrayList(arr2,(int) arr2.length()); - 
ndArrayList.addAll(newList); } - - return ndArrayList.array(); - } - - @Override - public INDArray transformFrom(List> tokens) { - NDArrayList ndArrayList = new NDArrayList(); - for(BagOfWordsTransform bagofWordsTransform : transforms) { - INDArray arr2 = bagofWordsTransform.transformFrom(tokens); - arr2 = arr2.reshape(arr2.length()); - NDArrayList newList = new NDArrayList(arr2,(int) arr2.length()); - ndArrayList.addAll(newList); - } - - return ndArrayList.array(); - } - - -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/TokenizerBagOfWordsTermSequenceIndexTransform.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/TokenizerBagOfWordsTermSequenceIndexTransform.java deleted file mode 100644 index 9b9483a4e..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/transforms/TokenizerBagOfWordsTermSequenceIndexTransform.java +++ /dev/null @@ -1,226 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.transforms; - -import lombok.Data; -import lombok.EqualsAndHashCode; -import org.datavec.api.transform.metadata.ColumnMetaData; -import org.datavec.api.transform.metadata.NDArrayMetaData; -import org.datavec.api.transform.schema.Schema; -import org.datavec.api.transform.transform.BaseColumnTransform; -import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; -import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; -import org.datavec.nlp.tokenization.tokenizer.Tokenizer; -import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory; -import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.Counter; -import org.nd4j.common.util.MathUtils; -import org.nd4j.shade.jackson.annotation.JsonCreator; -import org.nd4j.shade.jackson.annotation.JsonInclude; -import org.nd4j.shade.jackson.annotation.JsonProperty; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - -@Data -@EqualsAndHashCode(callSuper = true, exclude = {"tokenizerFactory"}) -@JsonInclude(JsonInclude.Include.NON_NULL) -public class TokenizerBagOfWordsTermSequenceIndexTransform extends BaseColumnTransform { - - private String newColumName; - private Map wordIndexMap; - private Map weightMap; - private boolean exceptionOnUnknown; - private String tokenizerFactoryClass; - private String preprocessorClass; - private TokenizerFactory tokenizerFactory; - - @JsonCreator - public 
TokenizerBagOfWordsTermSequenceIndexTransform(@JsonProperty("columnName") String columnName, - @JsonProperty("newColumnName") String newColumnName, - @JsonProperty("wordIndexMap") Map<String, Integer> wordIndexMap, - @JsonProperty("idfMap") Map<String, Double> idfMap, - @JsonProperty("exceptionOnUnknown") boolean exceptionOnUnknown, - @JsonProperty("tokenizerFactoryClass") String tokenizerFactoryClass, - @JsonProperty("preprocessorClass") String preprocessorClass) { - super(columnName); - this.newColumName = newColumnName; - this.wordIndexMap = wordIndexMap; - this.exceptionOnUnknown = exceptionOnUnknown; - this.weightMap = idfMap; - this.tokenizerFactoryClass = tokenizerFactoryClass; - this.preprocessorClass = preprocessorClass; - if(this.tokenizerFactoryClass == null) { - this.tokenizerFactoryClass = DefaultTokenizerFactory.class.getName(); - } - try { - tokenizerFactory = (TokenizerFactory) Class.forName(this.tokenizerFactoryClass).newInstance(); - } catch (Exception e) { - throw new IllegalStateException("Unable to instantiate tokenizer factory with empty constructor. Does the tokenizer factory class contain a default empty constructor?"); - } - - if(preprocessorClass != null){ - try { - TokenPreProcess tpp = (TokenPreProcess) Class.forName(this.preprocessorClass).newInstance(); - tokenizerFactory.setTokenPreProcessor(tpp); - } catch (Exception e){ - throw new IllegalStateException("Unable to instantiate token preprocessor with empty constructor. Does the preprocessor class contain a default empty constructor?"); - } - } - - } - - - - @Override - public List<Writable> map(List<Writable> writables) { - Text text = (Text) writables.get(inputSchema.getIndexOfColumn(columnName)); - List<Writable> ret = new ArrayList<>(writables); - ret.set(inputSchema.getIndexOfColumn(columnName),new NDArrayWritable(convert(text.toString()))); - return ret; - } - - @Override - public Object map(Object input) { - return convert(input.toString()); - } - - @Override - public Object mapSequence(Object sequence) { - return convert(sequence.toString()); - } - - @Override - public Schema transform(Schema inputSchema) { - Schema.Builder newSchema = new Schema.Builder(); - for(int i = 0; i < inputSchema.numColumns(); i++) { - if(inputSchema.getName(i).equals(this.columnName)) { - newSchema.addColumnNDArray(newColumName,new long[]{1,wordIndexMap.size()}); - } - else { - newSchema.addColumn(inputSchema.getMetaData(i)); - } - } - - return newSchema.build(); - } - - - /** - * Convert the given text - * into an {@link INDArray} - * using the {@link TokenizerFactory} - * specified in the constructor. - * @param text the text to transform - * @return the created {@link INDArray} - * based on the {@link #wordIndexMap} for the column indices - * of the word.
- */ - public INDArray convert(String text) { - Tokenizer tokenizer = tokenizerFactory.create(text); - List<String> tokens = tokenizer.getTokens(); - INDArray create = Nd4j.create(1,wordIndexMap.size()); - Counter<String> tokenizedCounter = new Counter<>(); - - for(int i = 0; i < tokens.size(); i++) { - tokenizedCounter.incrementCount(tokens.get(i),1.0); - } - - for(int i = 0; i < tokens.size(); i++) { - if(wordIndexMap.containsKey(tokens.get(i))) { - int idx = wordIndexMap.get(tokens.get(i)); - int count = (int) tokenizedCounter.getCount(tokens.get(i)); - double weight = tfidfWord(tokens.get(i),count,tokens.size()); - create.putScalar(idx,weight); - } - } - - return create; - } - - - /** - * Calculate the tf-idf weight for a word, - * given the word, its count, and the document length - * @param word the word to calculate - * @param wordCount the word frequency - * @param documentLength the number of words in the document - * @return the tf-idf weight for the given word - */ - public double tfidfWord(String word, long wordCount, long documentLength) { - double tf = tfForWord(wordCount, documentLength); - double idf = idfForWord(word); - return MathUtils.tfidf(tf, idf); - } - - /** - * Calculate the term frequency for a given word. - * Note: this is simply the raw count; the document length - * parameter is currently unused. - * @param wordCount the word frequency - * @param documentLength the number of words in the document - * @return the term frequency - */ - private double tfForWord(long wordCount, long documentLength) { - return wordCount; - } - - private double idfForWord(String word) { - if(weightMap.containsKey(word)) - return weightMap.get(word); - return 0; - } - - - @Override - public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) { - return new NDArrayMetaData(outputColumnName(),new long[]{1,wordIndexMap.size()}); - } - - @Override - public String outputColumnName() { - return newColumName; - } - - @Override - public String[] outputColumnNames() { - return new String[]{newColumName}; - } - - @Override - public String[] columnNames() { - return new String[]{columnName()}; - } - - @Override - public String columnName() { - return columnName; - } - - @Override - public Writable map(Writable columnWritable) { - return new NDArrayWritable(convert(columnWritable.toString())); - } -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/uima/UimaResource.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/uima/UimaResource.java deleted file mode 100644 index 5ba3a60b9..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/uima/UimaResource.java +++ /dev/null @@ -1,107 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
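A worked example of the weighting implemented above, with hypothetical maps wordIndexMap = {apple: 0, banana: 1} and idfMap = {apple: 1.5, banana: 0.5}. Since tfForWord returns the raw count and MathUtils.tfidf combines tf and idf (a plain product):

// text = "apple apple banana" -> counts {apple: 2, banana: 1}, document length 3
double apple  = MathUtils.tfidf(2, 1.5); // 2 * 1.5 = 3.0
double banana = MathUtils.tfidf(1, 0.5); // 1 * 0.5 = 0.5
// resulting row: [3.0, 0.5]; a word absent from idfMap gets idf 0, hence weight 0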
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.uima; - -import org.apache.uima.analysis_engine.AnalysisEngine; -import org.apache.uima.analysis_engine.AnalysisEngineProcessException; -import org.apache.uima.cas.CAS; -import org.apache.uima.resource.ResourceInitializationException; -import org.apache.uima.util.CasPool; - -public class UimaResource { - - private AnalysisEngine analysisEngine; - private CasPool casPool; - - public UimaResource(AnalysisEngine analysisEngine) throws ResourceInitializationException { - this.analysisEngine = analysisEngine; - this.casPool = new CasPool(Runtime.getRuntime().availableProcessors() * 10, analysisEngine); - - } - - public UimaResource(AnalysisEngine analysisEngine, CasPool casPool) { - this.analysisEngine = analysisEngine; - this.casPool = casPool; - - } - - - public AnalysisEngine getAnalysisEngine() { - return analysisEngine; - } - - - public void setAnalysisEngine(AnalysisEngine analysisEngine) { - this.analysisEngine = analysisEngine; - } - - - public CasPool getCasPool() { - return casPool; - } - - - public void setCasPool(CasPool casPool) { - this.casPool = casPool; - } - - - /** - * Uses the given analysis engine to process the given text. - * You must release the returned CAS yourself. - * @param text the text to process - * @return the processed CAS - */ - public CAS process(String text) { - CAS cas = retrieve(); - - cas.setDocumentText(text); - try { - analysisEngine.process(cas); - } catch (AnalysisEngineProcessException e) { - if (text != null && !text.isEmpty()) - return process(text); - throw new RuntimeException(e); - } - - return cas; - - - } - - - public CAS retrieve() { - CAS ret = casPool.getCas(); - try { - return ret == null ? analysisEngine.newCAS() : ret; - } catch (ResourceInitializationException e) { - throw new RuntimeException(e); - } - } - - - public void release(CAS cas) { - casPool.releaseCas(cas); - } - - - -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/AbstractTfidfVectorizer.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/AbstractTfidfVectorizer.java deleted file mode 100644 index 4988f0c3f..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/AbstractTfidfVectorizer.java +++ /dev/null @@ -1,77 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
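A short usage sketch of the pooling contract above; the caller owns the CAS between process() and release(). The pipeline here is the one built by UimaTokenizerFactory.defaultAnalysisEngine() shown earlier in this diff:

UimaResource resource = new UimaResource(UimaTokenizerFactory.defaultAnalysisEngine());
CAS cas = resource.process("Some text to annotate.");
try {
    // read annotations from the CAS here, e.g. JCasUtil.select(cas.getJCas(), Token.class)
} finally {
    resource.release(cas); // always return the CAS to the pool
}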
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.vectorizer; - -import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; -import org.datavec.api.records.reader.RecordReader; -import org.datavec.nlp.tokenization.tokenizer.TokenPreProcess; -import org.datavec.nlp.tokenization.tokenizer.Tokenizer; -import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory; -import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory; - -import java.util.HashSet; -import java.util.Set; - -public abstract class AbstractTfidfVectorizer<VECTOR_TYPE> extends TextVectorizer<VECTOR_TYPE> { - - @Override - public void doWithTokens(Tokenizer tokenizer) { - Set<String> seen = new HashSet<>(); - while (tokenizer.hasMoreTokens()) { - String token = tokenizer.nextToken(); - if (!stopWords.contains(token)) { - cache.incrementCount(token); - if (!seen.contains(token)) { - cache.incrementDocCount(token); - } - seen.add(token); - } - } - } - - @Override - public TokenizerFactory createTokenizerFactory(Configuration conf) { - String clazz = conf.get(TOKENIZER, DefaultTokenizerFactory.class.getName()); - try { - Class<? extends TokenizerFactory> tokenizerFactoryClazz = - (Class<? extends TokenizerFactory>) Class.forName(clazz); - TokenizerFactory tf = tokenizerFactoryClazz.newInstance(); - String preproc = conf.get(PREPROCESSOR, null); - if(preproc != null){ - TokenPreProcess tpp = (TokenPreProcess) Class.forName(preproc).newInstance(); - tf.setTokenPreProcessor(tpp); - } - return tf; - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public abstract VECTOR_TYPE createVector(Object[] args); - - @Override - public abstract VECTOR_TYPE fitTransform(RecordReader reader); - - @Override - public abstract VECTOR_TYPE transform(Record record); -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/TextVectorizer.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/TextVectorizer.java deleted file mode 100644 index 98e2fea4d..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/TextVectorizer.java +++ /dev/null @@ -1,121 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
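The doWithTokens logic above keeps two tallies per token: corpus-wide term frequency on every occurrence, and document frequency at most once per document via the seen set. A stand-in trace with plain maps instead of the VocabCache (names hypothetical):

Map<String, Long> termCounts = new HashMap<>(), docCounts = new HashMap<>();
Set<String> seen = new HashSet<>();
for (String token : Arrays.asList("apple", "apple", "banana")) { // one document
    termCounts.merge(token, 1L, Long::sum);   // tf: counted on every occurrence
    if (seen.add(token))                      // df: only the first occurrence per doc
        docCounts.merge(token, 1L, Long::sum);
}
// termCounts = {apple=2, banana=1}, docCounts = {apple=1, banana=1}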
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.vectorizer; - -import lombok.Getter; -import org.nd4j.common.primitives.Counter; -import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.vector.Vectorizer; -import org.datavec.api.writable.Writable; -import org.datavec.nlp.metadata.DefaultVocabCache; -import org.datavec.nlp.metadata.VocabCache; -import org.datavec.nlp.stopwords.StopWords; -import org.datavec.nlp.tokenization.tokenizer.Tokenizer; -import org.datavec.nlp.tokenization.tokenizerfactory.TokenizerFactory; - -import java.util.Collection; - -public abstract class TextVectorizer<VECTOR_TYPE> implements Vectorizer<VECTOR_TYPE> { - - protected TokenizerFactory tokenizerFactory; - protected int minWordFrequency = 0; - public final static String MIN_WORD_FREQUENCY = "org.nd4j.nlp.minwordfrequency"; - public final static String STOP_WORDS = "org.nd4j.nlp.stopwords"; - public final static String TOKENIZER = "org.datavec.nlp.tokenizerfactory"; - public static final String PREPROCESSOR = "org.datavec.nlp.preprocessor"; - public final static String VOCAB_CACHE = "org.datavec.nlp.vocabcache"; - protected Collection<String> stopWords; - @Getter - protected VocabCache cache; - - @Override - public void initialize(Configuration conf) { - tokenizerFactory = createTokenizerFactory(conf); - minWordFrequency = conf.getInt(MIN_WORD_FREQUENCY, 5); - if(conf.get(STOP_WORDS) != null) - stopWords = conf.getStringCollection(STOP_WORDS); - if (stopWords == null) - stopWords = StopWords.getStopWords(); - - String clazz = conf.get(VOCAB_CACHE, DefaultVocabCache.class.getName()); - try { - Class<? extends VocabCache> vocabCacheClazz = (Class<? extends VocabCache>) Class.forName(clazz); - cache = vocabCacheClazz.newInstance(); - cache.initialize(conf); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public void fit(RecordReader reader) { - fit(reader, null); - } - - @Override - public void fit(RecordReader reader, RecordCallBack callBack) { - while (reader.hasNext()) { - Record record = reader.nextRecord(); - String s = toString(record.getRecord()); - Tokenizer tokenizer = tokenizerFactory.create(s); - doWithTokens(tokenizer); - if (callBack != null) - callBack.onRecord(record); - cache.incrementNumDocs(1); - } - } - - - protected Counter<String> wordFrequenciesForRecord(Collection<Writable> record) { - String s = toString(record); - Tokenizer tokenizer = tokenizerFactory.create(s); - Counter<String> ret = new Counter<>(); - while (tokenizer.hasMoreTokens()) - ret.incrementCount(tokenizer.nextToken(), 1.0); - return ret; - } - - - protected String toString(Collection<Writable> record) { - StringBuilder sb = new StringBuilder(); - for(Writable w : record){ - sb.append(w.toString()); - } - return sb.toString(); - } - - - /** - * Increment counts, add to collection,...
- * @param tokenizer - */ - public abstract void doWithTokens(Tokenizer tokenizer); - - /** - * Create tokenizer factory based on the configuration - * @param conf the configuration to use - * @return the tokenizer factory based on the configuration - */ - public abstract TokenizerFactory createTokenizerFactory(Configuration conf); - -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/TfidfVectorizer.java b/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/TfidfVectorizer.java deleted file mode 100644 index 9a2f2db9e..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/java/org/datavec/nlp/vectorizer/TfidfVectorizer.java +++ /dev/null @@ -1,105 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.vectorizer; - - -import org.datavec.api.conf.Configuration; -import org.nd4j.common.primitives.Counter; -import org.datavec.api.records.Record; -import org.datavec.api.records.metadata.RecordMetaDataURI; -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.writable.NDArrayWritable; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -public class TfidfVectorizer extends AbstractTfidfVectorizer<INDArray> { - /** - * Default: True.
- * If true: use idf(d, t) = log [ (1 + n) / (1 + df(d, t)) ] + 1
- * If false: use idf(t) = log [ n / df(t) ] + 1
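- * Example (hypothetical numbers, natural log assumed): with n = 4 documents and df(d, t) = 1,
- * smooth idf = log [ (1 + 4) / (1 + 1) ] + 1 = log(2.5) + 1 ≈ 1.92, while the
- * unsmoothed variant gives log [ 4 / 1 ] + 1 ≈ 2.39.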
- */ - public static final String SMOOTH_IDF = "org.datavec.nlp.TfidfVectorizer.smooth_idf"; - - protected boolean smooth_idf; - - @Override - public INDArray createVector(Object[] args) { - Counter docFrequencies = (Counter) args[0]; - double[] vector = new double[cache.vocabWords().size()]; - for (int i = 0; i < cache.vocabWords().size(); i++) { - String word = cache.wordAt(i); - double freq = docFrequencies.getCount(word); - vector[i] = cache.tfidf(word, freq, smooth_idf); - } - return Nd4j.create(vector); - } - - @Override - public INDArray fitTransform(RecordReader reader) { - return fitTransform(reader, null); - } - - @Override - public INDArray fitTransform(final RecordReader reader, RecordCallBack callBack) { - final List records = new ArrayList<>(); - fit(reader, new RecordCallBack() { - @Override - public void onRecord(Record record) { - records.add(record); - } - }); - - if (records.isEmpty()) - throw new IllegalStateException("No records found!"); - INDArray ret = Nd4j.create(records.size(), cache.vocabWords().size()); - int i = 0; - for (Record record : records) { - INDArray transformed = transform(record); - org.datavec.api.records.impl.Record transformedRecord = new org.datavec.api.records.impl.Record( - Arrays.asList(new NDArrayWritable(transformed), - record.getRecord().get(record.getRecord().size() - 1)), - new RecordMetaDataURI(record.getMetaData().getURI(), reader.getClass())); - ret.putRow(i++, transformed); - if (callBack != null) { - callBack.onRecord(transformedRecord); - } - } - - return ret; - } - - @Override - public INDArray transform(Record record) { - Counter wordFrequencies = wordFrequenciesForRecord(record.getRecord()); - return createVector(new Object[] {wordFrequencies}); - } - - - @Override - public void initialize(Configuration conf){ - super.initialize(conf); - this.smooth_idf = conf.getBoolean(SMOOTH_IDF, true); - } -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/main/resources/stopwords b/datavec/datavec-data/datavec-data-nlp/src/main/resources/stopwords deleted file mode 100644 index f64dfcc52..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/main/resources/stopwords +++ /dev/null @@ -1,194 +0,0 @@ -a -----s -act -"the -"The -about -above -after -again -against -all -am -an -and -any -are -aren't -as -at -be -because -been -before -being -below -between -both -but -by -can't -cannot -could -couldn't -did -didn't -do -does -doesn't -doing -don't -down -during -each -few -for -from -further -had -hadn't -has -hasn't -have -haven't -having -he -he'd -he'll -he's -her -here -here's -hers -herself -him -himself -his -how -how's -i -i'd -i'll -i'm -i've -if -in -into -is -isn't -it -it's -its -itself -let's -me -more -most -mustn't -my -myself -no -nor -not -of -off -on -once -only -or -other -ought -our -ours -ourselves -out -over -own -put -same -shan't -she -she'd -she'll -she's -should -somebody -something -shouldn't -so -some -such -take -than -that -that's -the -their -theirs -them -themselves -then -there -there's -these -they -they'd -they'll -they're -they've -this -those -through -to -too -under -until -up -very -was -wasn't -we -we'd -we'll -we're -we've -were -weren't -what -what's -when -when's -where -where's -which -while -who -who's -whom -why -why's -will -with -without -won't -would -wouldn't -you -you'd -you'll -you're -you've -your -yours -yourself -yourselves -. -? -! 
-, -+ -= -also -- -; -: diff --git a/datavec/datavec-data/datavec-data-nlp/src/test/java/org/datavec/nlp/AssertTestsExtendBaseClass.java b/datavec/datavec-data/datavec-data-nlp/src/test/java/org/datavec/nlp/AssertTestsExtendBaseClass.java deleted file mode 100644 index 9c343f702..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/test/java/org/datavec/nlp/AssertTestsExtendBaseClass.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ -package org.datavec.nlp; - -import lombok.extern.slf4j.Slf4j; -import org.nd4j.common.tests.AbstractAssertTestsClass; -import org.nd4j.common.tests.BaseND4JTest; - -import java.util.*; - -@Slf4j -public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { - - @Override - protected Set> getExclusions() { - //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) - return new HashSet<>(); - } - - @Override - protected String getPackageName() { - return "org.datavec.nlp"; - } - - @Override - protected Class getBaseClass() { - return BaseND4JTest.class; - } -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/test/java/org/datavec/nlp/reader/TfidfRecordReaderTest.java b/datavec/datavec-data/datavec-data-nlp/src/test/java/org/datavec/nlp/reader/TfidfRecordReaderTest.java deleted file mode 100644 index 2ae8e684a..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/test/java/org/datavec/nlp/reader/TfidfRecordReaderTest.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
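Pulling the vectorizer pieces together before the tests below, a compact hypothetical end-to-end run (reader stands for any initialized text RecordReader, e.g. the TfidfRecordReader exercised next; the configuration keys are those defined on TextVectorizer):

Configuration conf = new Configuration();
conf.setInt(TfidfVectorizer.MIN_WORD_FREQUENCY, 1);
conf.setBoolean(RecordReader.APPEND_LABEL, true);
TfidfVectorizer vectorizer = new TfidfVectorizer();
vectorizer.initialize(conf);
INDArray docByVocab = vectorizer.fitTransform(reader); // rows = documents, cols = vocab terms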
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.reader; - -import org.datavec.api.conf.Configuration; -import org.datavec.api.records.Record; -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.split.CollectionInputSplit; -import org.datavec.api.split.FileSplit; -import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; -import org.datavec.nlp.vectorizer.TfidfVectorizer; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.nd4j.common.io.ClassPathResource; - -import java.io.File; -import java.net.URI; -import java.util.*; - -import static org.junit.Assert.*; - -/** - * @author Adam Gibson - */ -public class TfidfRecordReaderTest { - - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - - @Test - public void testReader() throws Exception { - TfidfVectorizer vectorizer = new TfidfVectorizer(); - Configuration conf = new Configuration(); - conf.setInt(TfidfVectorizer.MIN_WORD_FREQUENCY, 1); - conf.setBoolean(RecordReader.APPEND_LABEL, true); - vectorizer.initialize(conf); - TfidfRecordReader reader = new TfidfRecordReader(); - File f = testDir.newFolder(); - new ClassPathResource("datavec-data-nlp/labeled/").copyDirectory(f); - List u = new ArrayList<>(); - for(File f2 : f.listFiles()){ - if(f2.isDirectory()){ - for(File f3 : f2.listFiles()){ - u.add(f3.toURI()); - } - } else { - u.add(f2.toURI()); - } - } - Collections.sort(u); - CollectionInputSplit c = new CollectionInputSplit(u); - reader.initialize(conf, c); - int count = 0; - int[] labelAssertions = new int[3]; - while (reader.hasNext()) { - Collection record = reader.next(); - Iterator recordIter = record.iterator(); - NDArrayWritable writable = (NDArrayWritable) recordIter.next(); - labelAssertions[count] = recordIter.next().toInt(); - count++; - } - - assertArrayEquals(new int[] {0, 1, 2}, labelAssertions); - assertEquals(3, reader.getLabels().size()); - assertEquals(3, count); - } - - @Test - public void testRecordMetaData() throws Exception { - TfidfVectorizer vectorizer = new TfidfVectorizer(); - Configuration conf = new Configuration(); - conf.setInt(TfidfVectorizer.MIN_WORD_FREQUENCY, 1); - conf.setBoolean(RecordReader.APPEND_LABEL, true); - vectorizer.initialize(conf); - TfidfRecordReader reader = new TfidfRecordReader(); - File f = testDir.newFolder(); - new ClassPathResource("datavec-data-nlp/labeled/").copyDirectory(f); - reader.initialize(conf, new FileSplit(f)); - - while (reader.hasNext()) { - Record record = reader.nextRecord(); - assertNotNull(record.getMetaData().getURI()); - assertEquals(record.getMetaData().getReaderClass(), TfidfRecordReader.class); - } - } - - - @Test - public void testReadRecordFromMetaData() throws Exception { - TfidfVectorizer vectorizer = new TfidfVectorizer(); - Configuration conf = new Configuration(); - conf.setInt(TfidfVectorizer.MIN_WORD_FREQUENCY, 1); - conf.setBoolean(RecordReader.APPEND_LABEL, true); - vectorizer.initialize(conf); - TfidfRecordReader reader = new TfidfRecordReader(); - File f = testDir.newFolder(); - new ClassPathResource("datavec-data-nlp/labeled/").copyDirectory(f); - reader.initialize(conf, new FileSplit(f)); - - Record record = reader.nextRecord(); - - Record reread = reader.loadFromMetaData(record.getMetaData()); - - assertEquals(record.getRecord().size(), 2); - assertEquals(reread.getRecord().size(), 2); - 
assertEquals(record.getRecord().get(0), reread.getRecord().get(0)); - assertEquals(record.getRecord().get(1), reread.getRecord().get(1)); - assertEquals(record.getMetaData(), reread.getMetaData()); - } -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestGazeteerTransform.java b/datavec/datavec-data/datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestGazeteerTransform.java deleted file mode 100644 index 6f567d1af..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestGazeteerTransform.java +++ /dev/null @@ -1,96 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.transforms; - -import org.datavec.api.transform.TransformProcess; -import org.datavec.api.transform.schema.SequenceSchema; -import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; -import org.datavec.local.transforms.LocalTransformExecutor; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - -import static org.junit.Assert.assertEquals; - -public class TestGazeteerTransform { - - @Test - public void testGazeteerTransform(){ - - String[] corpus = { - "hello I like apple".toLowerCase(), - "cherry date eggplant potato".toLowerCase() - }; - - //Gazeteer transform: basically 0/1 if word is present. 
Assumes already tokenized input - List<String> words = Arrays.asList("apple", "banana", "cherry", "date", "eggplant"); - - GazeteerTransform t = new GazeteerTransform("words", "out", words); - - SequenceSchema schema = (SequenceSchema) new SequenceSchema.Builder() - .addColumnString("words").build(); - - TransformProcess tp = new TransformProcess.Builder(schema) - .transform(t) - .build(); - - List<List<List<Writable>>> input = new ArrayList<>(); - for(String s : corpus){ - String[] split = s.split(" "); - List<List<Writable>> seq = new ArrayList<>(); - for(String s2 : split){ - seq.add(Collections.<Writable>singletonList(new Text(s2))); - } - input.add(seq); - } - - List<List<List<Writable>>> execute = LocalTransformExecutor.executeSequenceToSequence(input, tp); - - INDArray arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get(); - INDArray arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get(); - - INDArray exp0 = Nd4j.create(new float[]{1, 0, 0, 0, 0}); - INDArray exp1 = Nd4j.create(new float[]{0, 0, 1, 1, 1}); - - assertEquals(exp0, arr0); - assertEquals(exp1, arr1); - - - String json = tp.toJson(); - TransformProcess tp2 = TransformProcess.fromJson(json); - assertEquals(tp, tp2); - - List<List<List<Writable>>> execute2 = LocalTransformExecutor.executeSequenceToSequence(input, tp); - INDArray arr0a = ((NDArrayWritable)execute2.get(0).get(0).get(0)).get(); - INDArray arr1a = ((NDArrayWritable)execute2.get(0).get(1).get(0)).get(); - - assertEquals(exp0, arr0a); - assertEquals(exp1, arr1a); - } - -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestMultiNLPTransform.java b/datavec/datavec-data/datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestMultiNLPTransform.java deleted file mode 100644 index a2194d6f9..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TestMultiNLPTransform.java +++ /dev/null @@ -1,96 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.transforms; - -import org.datavec.api.transform.TransformProcess; -import org.datavec.api.transform.schema.SequenceSchema; -import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; -import org.datavec.local.transforms.LocalTransformExecutor; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.*; - -import static org.junit.Assert.assertEquals; - -public class TestMultiNLPTransform { - - @Test - public void test(){ - - List<String> words = Arrays.asList("apple", "banana", "cherry", "date", "eggplant"); - GazeteerTransform t1 = new GazeteerTransform("words", "out", words); - GazeteerTransform t2 = new GazeteerTransform("out", "out", words); - - - MultiNlpTransform multi = new MultiNlpTransform("text", new BagOfWordsTransform[]{t1, t2}, "out"); - - String[] corpus = { - "hello I like apple".toLowerCase(), - "date eggplant potato".toLowerCase() - }; - - List<List<List<Writable>>> input = new ArrayList<>(); - for(String s : corpus){ - String[] split = s.split(" "); - List<List<Writable>> seq = new ArrayList<>(); - for(String s2 : split){ - seq.add(Collections.<Writable>singletonList(new Text(s2))); - } - input.add(seq); - } - - SequenceSchema schema = (SequenceSchema) new SequenceSchema.Builder() - .addColumnString("text").build(); - - TransformProcess tp = new TransformProcess.Builder(schema) - .transform(multi) - .build(); - - List<List<List<Writable>>> execute = LocalTransformExecutor.executeSequenceToSequence(input, tp); - - INDArray arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get(); - INDArray arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get(); - - INDArray exp0 = Nd4j.create(new float[]{1, 0, 0, 0, 0, 1, 0, 0, 0, 0}); - INDArray exp1 = Nd4j.create(new float[]{0, 0, 0, 1, 1, 0, 0, 0, 1, 1}); - - assertEquals(exp0, arr0); - assertEquals(exp1, arr1); - - - String json = tp.toJson(); - TransformProcess tp2 = TransformProcess.fromJson(json); - assertEquals(tp, tp2); - - List<List<List<Writable>>> execute2 = LocalTransformExecutor.executeSequenceToSequence(input, tp); - INDArray arr0a = ((NDArrayWritable)execute2.get(0).get(0).get(0)).get(); - INDArray arr1a = ((NDArrayWritable)execute2.get(0).get(1).get(0)).get(); - - assertEquals(exp0, arr0a); - assertEquals(exp1, arr1a); - - } - -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TokenizerBagOfWordsTermSequenceIndexTransformTest.java b/datavec/datavec-data/datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TokenizerBagOfWordsTermSequenceIndexTransformTest.java deleted file mode 100644 index 3d16997da..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/test/java/org/datavec/nlp/transforms/TokenizerBagOfWordsTermSequenceIndexTransformTest.java +++ /dev/null @@ -1,414 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. 
- * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.nlp.transforms; - -import org.datavec.api.conf.Configuration; -import org.datavec.api.records.reader.impl.collection.CollectionRecordReader; -import org.datavec.api.transform.TransformProcess; -import org.datavec.api.transform.schema.SequenceSchema; -import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; -import org.datavec.local.transforms.LocalTransformExecutor; -import org.datavec.nlp.metadata.VocabCache; -import org.datavec.nlp.tokenization.tokenizer.preprocessor.LowerCasePreProcessor; -import org.datavec.nlp.tokenization.tokenizerfactory.DefaultTokenizerFactory; -import org.datavec.nlp.vectorizer.TfidfVectorizer; -import org.junit.Test; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.Triple; - -import java.util.*; - -import static org.datavec.nlp.vectorizer.TextVectorizer.*; -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; - -public class TokenizerBagOfWordsTermSequenceIndexTransformTest { - - @Test - public void testSequenceExecution() { - //credit: https://stackoverflow.com/questions/23792781/tf-idf-feature-weights-using-sklearn-feature-extraction-text-tfidfvectorizer - String[] corpus = { - "This is very strange".toLowerCase(), - "This is very nice".toLowerCase() - }; - //{'is': 1.0, 'nice': 1.4054651081081644, 'strange': 1.4054651081081644, 'this': 1.0, 'very': 1.0} - - /* - ## Reproduce with: - from sklearn.feature_extraction.text import TfidfVectorizer - corpus = ["This is very strange", "This is very nice"] - - ## SMOOTH = FALSE case: - vectorizer = TfidfVectorizer(min_df=0, norm=None, smooth_idf=False) - X = vectorizer.fit_transform(corpus) - idf = vectorizer.idf_ - print(dict(zip(vectorizer.get_feature_names(), idf))) - - newText = ["This is very strange", "This is very nice"] - out = vectorizer.transform(newText) - print(out) - - {'is': 1.0, 'nice': 1.6931471805599454, 'strange': 1.6931471805599454, 'this': 1.0, 'very': 1.0} - (0, 4) 1.0 - (0, 3) 1.0 - (0, 2) 1.6931471805599454 - (0, 0) 1.0 - (1, 4) 1.0 - (1, 3) 1.0 - (1, 1) 1.6931471805599454 - (1, 0) 1.0 - - ## SMOOTH = TRUE case: - {'is': 1.0, 'nice': 1.4054651081081644, 'strange': 1.4054651081081644, 'this': 1.0, 'very': 1.0} - (0, 4) 1.0 - (0, 3) 1.0 - (0, 2) 1.4054651081081644 - (0, 0) 1.0 - (1, 4) 1.0 - (1, 3) 1.0 - (1, 1) 1.4054651081081644 - (1, 0) 1.0 - */ - - List<List<List<Writable>>> input = new ArrayList<>(); - input.add(Arrays.asList(Arrays.<Writable>asList(new Text(corpus[0])),Arrays.<Writable>asList(new Text(corpus[1])))); - - // First: Check TfidfVectorizer vs. 
scikit: - - Map<String, Double> idfMapNoSmooth = new HashMap<>(); - idfMapNoSmooth.put("is",1.0); - idfMapNoSmooth.put("nice",1.6931471805599454); - idfMapNoSmooth.put("strange",1.6931471805599454); - idfMapNoSmooth.put("this",1.0); - idfMapNoSmooth.put("very",1.0); - - Map<String, Double> idfMapSmooth = new HashMap<>(); - idfMapSmooth.put("is",1.0); - idfMapSmooth.put("nice",1.4054651081081644); - idfMapSmooth.put("strange",1.4054651081081644); - idfMapSmooth.put("this",1.0); - idfMapSmooth.put("very",1.0); - - - - TfidfVectorizer tfidfVectorizer = new TfidfVectorizer(); - Configuration configuration = new Configuration(); - configuration.set(TOKENIZER, DefaultTokenizerFactory.class.getName()); - configuration.set(MIN_WORD_FREQUENCY,"1"); - configuration.set(STOP_WORDS,""); - configuration.set(TfidfVectorizer.SMOOTH_IDF, "false"); - - tfidfVectorizer.initialize(configuration); - - CollectionRecordReader collectionRecordReader = new CollectionRecordReader(input.get(0)); - INDArray array = tfidfVectorizer.fitTransform(collectionRecordReader); - - INDArray expNoSmooth = Nd4j.create(DataType.FLOAT, 2, 5); - VocabCache vc = tfidfVectorizer.getCache(); - expNoSmooth.putScalar(0, vc.wordIndex("very"), 1.0); - expNoSmooth.putScalar(0, vc.wordIndex("this"), 1.0); - expNoSmooth.putScalar(0, vc.wordIndex("strange"), 1.6931471805599454); - expNoSmooth.putScalar(0, vc.wordIndex("is"), 1.0); - - expNoSmooth.putScalar(1, vc.wordIndex("very"), 1.0); - expNoSmooth.putScalar(1, vc.wordIndex("this"), 1.0); - expNoSmooth.putScalar(1, vc.wordIndex("nice"), 1.6931471805599454); - expNoSmooth.putScalar(1, vc.wordIndex("is"), 1.0); - - assertEquals(expNoSmooth, array); - - - //------------------------------------------------------------ - //Smooth version: - tfidfVectorizer = new TfidfVectorizer(); - configuration = new Configuration(); - configuration.set(TOKENIZER, DefaultTokenizerFactory.class.getName()); - configuration.set(MIN_WORD_FREQUENCY,"1"); - configuration.set(STOP_WORDS,""); - configuration.set(TfidfVectorizer.SMOOTH_IDF, "true"); - - tfidfVectorizer.initialize(configuration); - - collectionRecordReader.reset(); - array = tfidfVectorizer.fitTransform(collectionRecordReader); - - INDArray expSmooth = Nd4j.create(DataType.FLOAT, 2, 5); - expSmooth.putScalar(0, vc.wordIndex("very"), 1.0); - expSmooth.putScalar(0, vc.wordIndex("this"), 1.0); - expSmooth.putScalar(0, vc.wordIndex("strange"), 1.4054651081081644); - expSmooth.putScalar(0, vc.wordIndex("is"), 1.0); - - expSmooth.putScalar(1, vc.wordIndex("very"), 1.0); - expSmooth.putScalar(1, vc.wordIndex("this"), 1.0); - expSmooth.putScalar(1, vc.wordIndex("nice"), 1.4054651081081644); - expSmooth.putScalar(1, vc.wordIndex("is"), 1.0); - - assertEquals(expSmooth, array); - - - ////////////////////////////////////////////////////////// - - //Second: Check transform vs scikit/TfidfVectorizer - - List<String> vocab = new ArrayList<>(5); //Arrays.asList("is","nice","strange","this","very"); - for( int i=0; i<5; i++ ){ - vocab.add(vc.wordAt(i)); - } - - String inputColumnName = "input"; - String outputColumnName = "output"; - Map<String, Integer> wordIndexMap = new HashMap<>(); - for(int i = 0; i < vocab.size(); i++) { - wordIndexMap.put(vocab.get(i),i); - } - - TokenizerBagOfWordsTermSequenceIndexTransform tokenizerBagOfWordsTermSequenceIndexTransform = new TokenizerBagOfWordsTermSequenceIndexTransform( - inputColumnName, - outputColumnName, - wordIndexMap, - idfMapNoSmooth, - false, - null, null); - - SequenceSchema.Builder sequenceSchemaBuilder = new SequenceSchema.Builder(); - 
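// Editorial worked example (not in the original test): with smooth_idf=False, scikit-learn - // computes idf(t) = ln(nDocs / df(t)) + 1, so a term appearing in 1 of the 2 documents gets - // ln(2/1) + 1 = 1.6931471805599454, matching idfMapNoSmooth above; with smooth_idf=True it is - // idf(t) = ln((1 + nDocs) / (1 + df(t))) + 1, i.e. ln(3/2) + 1 = 1.4054651081081644, matching - // idfMapSmooth. Terms present in every document get idf = ln(1) + 1 = 1.0 in both cases. - 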
sequenceSchemaBuilder.addColumnString("input"); - SequenceSchema schema = sequenceSchemaBuilder.build(); - assertEquals("input",schema.getName(0)); - - TransformProcess transformProcess = new TransformProcess.Builder(schema) - .transform(tokenizerBagOfWordsTermSequenceIndexTransform) - .build(); - - List<List<List<Writable>>> execute = LocalTransformExecutor.executeSequenceToSequence(input, transformProcess); - - - - //System.out.println(execute); - INDArray arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get(); - INDArray arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get(); - - assertEquals(expNoSmooth.getRow(0, true), arr0); - assertEquals(expNoSmooth.getRow(1, true), arr1); - - - //-------------------------------- - //Check smooth: - - tokenizerBagOfWordsTermSequenceIndexTransform = new TokenizerBagOfWordsTermSequenceIndexTransform( - inputColumnName, - outputColumnName, - wordIndexMap, - idfMapSmooth, - false, - null, null); - - schema = (SequenceSchema) new SequenceSchema.Builder().addColumnString("input").build(); - - transformProcess = new TransformProcess.Builder(schema) - .transform(tokenizerBagOfWordsTermSequenceIndexTransform) - .build(); - - execute = LocalTransformExecutor.executeSequenceToSequence(input, transformProcess); - - arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get(); - arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get(); - - assertEquals(expSmooth.getRow(0, true), arr0); - assertEquals(expSmooth.getRow(1, true), arr1); - - - - //Test JSON serialization: - - String json = transformProcess.toJson(); - TransformProcess fromJson = TransformProcess.fromJson(json); - assertEquals(transformProcess, fromJson); - List<List<List<Writable>>> execute2 = LocalTransformExecutor.executeSequenceToSequence(input, fromJson); - - INDArray arr0a = ((NDArrayWritable)execute2.get(0).get(0).get(0)).get(); - INDArray arr1a = ((NDArrayWritable)execute2.get(0).get(1).get(0)).get(); - - assertEquals(expSmooth.getRow(0, true), arr0a); - assertEquals(expSmooth.getRow(1, true), arr1a); - } - - @Test - public void additionalTest(){ - /* - ## To reproduce: - from sklearn.feature_extraction.text import TfidfVectorizer - corpus = [ - 'This is the first document', - 'This document is the second document', - 'And this is the third one', - 'Is this the first document', - ] - vectorizer = TfidfVectorizer(min_df=0, norm=None, smooth_idf=False) - X = vectorizer.fit_transform(corpus) - print(vectorizer.get_feature_names()) - - out = vectorizer.transform(corpus) - print(out) - - ['and', 'document', 'first', 'is', 'one', 'second', 'the', 'third', 'this'] - (0, 8) 1.0 - (0, 6) 1.0 - (0, 3) 1.0 - (0, 2) 1.6931471805599454 - (0, 1) 1.2876820724517808 - (1, 8) 1.0 - (1, 6) 1.0 - (1, 5) 2.386294361119891 - (1, 3) 1.0 - (1, 1) 2.5753641449035616 - (2, 8) 1.0 - (2, 7) 2.386294361119891 - (2, 6) 1.0 - (2, 4) 2.386294361119891 - (2, 3) 1.0 - (2, 0) 2.386294361119891 - (3, 8) 1.0 - (3, 6) 1.0 - (3, 3) 1.0 - (3, 2) 1.6931471805599454 - (3, 1) 1.2876820724517808 - {'and': 2.386294361119891, 'document': 1.2876820724517808, 'first': 1.6931471805599454, 'is': 1.0, 'one': 2.386294361119891, 'second': 2.386294361119891, 'the': 1.0, 'third': 2.386294361119891, 'this': 1.0} - */ - - String[] corpus = { - "This is the first document", - "This document is the second document", - "And this is the third one", - "Is this the first document"}; - - TfidfVectorizer tfidfVectorizer = new TfidfVectorizer(); - Configuration configuration = new Configuration(); - configuration.set(TOKENIZER, DefaultTokenizerFactory.class.getName()); - 
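// Editorial note (not in the original test): unlike testSequenceExecution(), this corpus is mixed case, - // so the LowerCasePreProcessor configured below is what makes "This" and "Is" count toward the same - // vocabulary entries as "this" and "is"; without it the scikit-learn values above would not be reproduced. - 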
configuration.set(MIN_WORD_FREQUENCY,"1"); - configuration.set(STOP_WORDS,""); - configuration.set(TfidfVectorizer.SMOOTH_IDF, "false"); - configuration.set(PREPROCESSOR, LowerCasePreProcessor.class.getName()); - - tfidfVectorizer.initialize(configuration); - - List<List<List<Writable>>> input = new ArrayList<>(); - //input.add(Arrays.asList(Arrays.<Writable>asList(new Text(corpus[0])),Arrays.<Writable>asList(new Text(corpus[1])))); - List<List<Writable>> seq = new ArrayList<>(); - for(String s : corpus){ - seq.add(Collections.<Writable>singletonList(new Text(s))); - } - input.add(seq); - - CollectionRecordReader crr = new CollectionRecordReader(seq); - INDArray arr = tfidfVectorizer.fitTransform(crr); - - //System.out.println(arr); - assertArrayEquals(new long[]{4, 9}, arr.shape()); - - List<String> pyVocab = Arrays.asList("and", "document", "first", "is", "one", "second", "the", "third", "this"); - List<Triple<Integer, Integer, Double>> l = new ArrayList<>(); - l.add(new Triple<>(0, 8, 1.0)); - l.add(new Triple<>(0, 6, 1.0)); - l.add(new Triple<>(0, 3, 1.0)); - l.add(new Triple<>(0, 2, 1.6931471805599454)); - l.add(new Triple<>(0, 1, 1.2876820724517808)); - l.add(new Triple<>(1, 8, 1.0)); - l.add(new Triple<>(1, 6, 1.0)); - l.add(new Triple<>(1, 5, 2.386294361119891)); - l.add(new Triple<>(1, 3, 1.0)); - l.add(new Triple<>(1, 1, 2.5753641449035616)); - l.add(new Triple<>(2, 8, 1.0)); - l.add(new Triple<>(2, 7, 2.386294361119891)); - l.add(new Triple<>(2, 6, 1.0)); - l.add(new Triple<>(2, 4, 2.386294361119891)); - l.add(new Triple<>(2, 3, 1.0)); - l.add(new Triple<>(2, 0, 2.386294361119891)); - l.add(new Triple<>(3, 8, 1.0)); - l.add(new Triple<>(3, 6, 1.0)); - l.add(new Triple<>(3, 3, 1.0)); - l.add(new Triple<>(3, 2, 1.6931471805599454)); - l.add(new Triple<>(3, 1, 1.2876820724517808)); - - INDArray exp = Nd4j.create(DataType.FLOAT, 4, 9); - for(Triple<Integer, Integer, Double> t : l){ - //Work out word index, accounting for different vocab/word orders: - int wIdx = tfidfVectorizer.getCache().wordIndex(pyVocab.get(t.getSecond())); - exp.putScalar(t.getFirst(), wIdx, t.getThird()); - } - - assertEquals(exp, arr); - - - Map<String, Double> idfWeights = new HashMap<>(); - idfWeights.put("and", 2.386294361119891); - idfWeights.put("document", 1.2876820724517808); - idfWeights.put("first", 1.6931471805599454); - idfWeights.put("is", 1.0); - idfWeights.put("one", 2.386294361119891); - idfWeights.put("second", 2.386294361119891); - idfWeights.put("the", 1.0); - idfWeights.put("third", 2.386294361119891); - idfWeights.put("this", 1.0); - - - List<String> vocab = new ArrayList<>(9); //Arrays.asList("is","nice","strange","this","very"); - for( int i=0; i<9; i++ ){ - vocab.add(tfidfVectorizer.getCache().wordAt(i)); - } - - String inputColumnName = "input"; - String outputColumnName = "output"; - Map<String, Integer> wordIndexMap = new HashMap<>(); - for(int i = 0; i < vocab.size(); i++) { - wordIndexMap.put(vocab.get(i),i); - } - - TokenizerBagOfWordsTermSequenceIndexTransform transform = new TokenizerBagOfWordsTermSequenceIndexTransform( - inputColumnName, - outputColumnName, - wordIndexMap, - idfWeights, - false, - null, LowerCasePreProcessor.class.getName()); - - SequenceSchema.Builder sequenceSchemaBuilder = new SequenceSchema.Builder(); - sequenceSchemaBuilder.addColumnString("input"); - SequenceSchema schema = sequenceSchemaBuilder.build(); - assertEquals("input",schema.getName(0)); - - TransformProcess transformProcess = new TransformProcess.Builder(schema) - .transform(transform) - .build(); - - List<List<List<Writable>>> execute = LocalTransformExecutor.executeSequenceToSequence(input, transformProcess); - - INDArray arr0 = ((NDArrayWritable)execute.get(0).get(0).get(0)).get(); 
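- // Editorial note (not in the original test): executeSequenceToSequence returns one output sequence per - // input sequence, and each sequence step holds a single NDArrayWritable, so execute.get(0).get(i).get(0) - // unwraps the TF-IDF row for document i.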
- INDArray arr1 = ((NDArrayWritable)execute.get(0).get(1).get(0)).get(); - - assertEquals(exp.getRow(0, true), arr0); - assertEquals(exp.getRow(1, true), arr1); - } - -} diff --git a/datavec/datavec-data/datavec-data-nlp/src/test/resources/logback.xml b/datavec/datavec-data/datavec-data-nlp/src/test/resources/logback.xml deleted file mode 100644 index abb9912c7..000000000 --- a/datavec/datavec-data/datavec-data-nlp/src/test/resources/logback.xml +++ /dev/null @@ -1,53 +0,0 @@ - - - - - - logs/application.log - - %date - [%level] - from %logger in %thread - %n%message%n%xException%n - - - - - - %logger{15} - %message%n%xException{5} - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/datavec/datavec-data/datavec-geo/pom.xml b/datavec/datavec-data/datavec-geo/pom.xml deleted file mode 100644 index b19518faa..000000000 --- a/datavec/datavec-data/datavec-geo/pom.xml +++ /dev/null @@ -1,56 +0,0 @@ - - - - - - 4.0.0 - - - org.datavec - datavec-data - 1.0.0-SNAPSHOT - - - datavec-geo - - - - org.datavec - datavec-api - - - com.maxmind.geoip2 - geoip2 - ${geoip2.version} - - - - - - test-nd4j-native - - - test-nd4j-cuda-11.0 - - - diff --git a/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/geo/LocationType.java b/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/geo/LocationType.java deleted file mode 100644 index a1ae236d7..000000000 --- a/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/geo/LocationType.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.api.transform.geo; - -public enum LocationType { - CITY, CITY_ID, CONTINENT, CONTINENT_ID, COUNTRY, COUNTRY_ID, COORDINATES, POSTAL_CODE, SUBDIVISIONS, SUBDIVISIONS_ID -} diff --git a/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/reduce/geo/CoordinatesReduction.java b/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/reduce/geo/CoordinatesReduction.java deleted file mode 100644 index 50459850c..000000000 --- a/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/reduce/geo/CoordinatesReduction.java +++ /dev/null @@ -1,194 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. 
- * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.api.transform.reduce.geo; - -import lombok.Getter; -import org.datavec.api.transform.ReduceOp; -import org.datavec.api.transform.metadata.ColumnMetaData; -import org.datavec.api.transform.metadata.StringMetaData; -import org.datavec.api.transform.ops.IAggregableReduceOp; -import org.datavec.api.transform.reduce.AggregableColumnReduction; -import org.datavec.api.transform.reduce.AggregableReductionUtils; -import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.DoubleWritable; -import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; -import org.nd4j.common.function.Supplier; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -public class CoordinatesReduction implements AggregableColumnReduction { - public static final String DEFAULT_COLUMN_NAME = "CoordinatesReduction"; - - public final static String DEFAULT_DELIMITER = ":"; - protected String delimiter = DEFAULT_DELIMITER; - - private final List<String> columnNamesPostReduce; - - private final Supplier<IAggregableReduceOp<Writable, List<Writable>>> multiOp(final List<ReduceOp> ops) { - return new Supplier<IAggregableReduceOp<Writable, List<Writable>>>() { - @Override - public IAggregableReduceOp<Writable, List<Writable>> get() { - return AggregableReductionUtils.reduceDoubleColumn(ops, false, null); - } - }; - } - - public CoordinatesReduction(String columnNamePostReduce, ReduceOp op) { - this(columnNamePostReduce, op, DEFAULT_DELIMITER); - } - - public CoordinatesReduction(List<String> columnNamePostReduce, List<ReduceOp> op) { - this(columnNamePostReduce, op, DEFAULT_DELIMITER); - } - - public CoordinatesReduction(String columnNamePostReduce, ReduceOp op, String delimiter) { - this(Collections.singletonList(columnNamePostReduce), Collections.singletonList(op), delimiter); - } - - public CoordinatesReduction(List<String> columnNamesPostReduce, List<ReduceOp> ops, String delimiter) { - this.columnNamesPostReduce = columnNamesPostReduce; - this.reducer = new CoordinateAggregableReduceOp(ops.size(), multiOp(ops), delimiter); - } - - @Override - public List<String> getColumnsOutputName(String columnInputName) { - return columnNamesPostReduce; - } - - @Override - public List<ColumnMetaData> getColumnOutputMetaData(List<String> newColumnName, ColumnMetaData columnInputMeta) { - List<ColumnMetaData> res = new ArrayList<>(newColumnName.size()); - for (String cn : newColumnName) - res.add(new StringMetaData((cn))); - return res; - } - - @Override - public Schema transform(Schema inputSchema) { - throw new UnsupportedOperationException(); - } - - @Override - public void setInputSchema(Schema inputSchema) { - throw new UnsupportedOperationException(); - } - - @Override - public Schema getInputSchema() { - throw new UnsupportedOperationException(); - } - - @Override - public String outputColumnName() { - throw new UnsupportedOperationException(); - } - - @Override - public String[] outputColumnNames() { - throw new UnsupportedOperationException(); - } - - @Override - public String[] columnNames() { - throw new UnsupportedOperationException(); - } - - @Override - public String columnName() { - throw new UnsupportedOperationException(); - } - - private IAggregableReduceOp<Writable, List<Writable>> reducer; - - @Override - public IAggregableReduceOp<Writable, List<Writable>> reduceOp() { - return reducer; - } - - - public static class CoordinateAggregableReduceOp implements IAggregableReduceOp<Writable, List<Writable>> { - - - private int nOps; - private Supplier<IAggregableReduceOp<Writable, List<Writable>>> initialOpValue; - @Getter - private ArrayList<IAggregableReduceOp<Writable, List<Writable>>> perCoordinateOps; // of size coords() - private String delimiter; - - public CoordinateAggregableReduceOp(int n, Supplier<IAggregableReduceOp<Writable, List<Writable>>> initialOp, - String delim) { - this.nOps = n; - this.perCoordinateOps = new ArrayList<>(); - this.initialOpValue = initialOp; - this.delimiter = delim; - } - - @Override - public <W extends IAggregableReduceOp<Writable, List<Writable>>> void combine(W accu) { - if (accu instanceof CoordinateAggregableReduceOp) { - CoordinateAggregableReduceOp accumulator = (CoordinateAggregableReduceOp) accu; - for (int i = 0; i < Math.min(perCoordinateOps.size(), accumulator.getPerCoordinateOps().size()); i++) { - perCoordinateOps.get(i).combine(accumulator.getPerCoordinateOps().get(i)); - } // the rest is assumed identical - } - } - - @Override - public void accept(Writable writable) { - String[] coordinates = writable.toString().split(delimiter); - for (int i = 0; i < coordinates.length; i++) { - String coordinate = coordinates[i]; - while (perCoordinateOps.size() < i + 1) { - perCoordinateOps.add(initialOpValue.get()); - } - perCoordinateOps.get(i).accept(new DoubleWritable(Double.parseDouble(coordinate))); - } - } - - @Override - public List<Writable> get() { - List<StringBuilder> res = new ArrayList<>(nOps); - for (int i = 0; i < nOps; i++) { - res.add(new StringBuilder()); - } - - for (int i = 0; i < perCoordinateOps.size(); i++) { - List<Writable> resThisCoord = perCoordinateOps.get(i).get(); - for (int j = 0; j < nOps; j++) { - res.get(j).append(resThisCoord.get(j).toString()); - if (i < perCoordinateOps.size() - 1) { - res.get(j).append(delimiter); - } - } - } - - List<Writable> finalRes = new ArrayList<>(nOps); - for (StringBuilder sb : res) { - finalRes.add(new Text(sb.toString())); - } - return finalRes; - } - } - -} diff --git a/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/CoordinatesDistanceTransform.java b/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/CoordinatesDistanceTransform.java deleted file mode 100644 index dacd09222..000000000 --- a/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/CoordinatesDistanceTransform.java +++ /dev/null @@ -1,117 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.api.transform.transform.geo; - -import org.datavec.api.transform.MathOp; -import org.datavec.api.transform.metadata.ColumnMetaData; -import org.datavec.api.transform.metadata.DoubleMetaData; -import org.datavec.api.transform.schema.Schema; -import org.datavec.api.transform.transform.BaseColumnsMathOpTransform; -import org.datavec.api.writable.DoubleWritable; -import org.datavec.api.writable.Writable; -import org.nd4j.shade.jackson.annotation.JsonProperty; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -public class CoordinatesDistanceTransform extends BaseColumnsMathOpTransform { - - public final static String DEFAULT_DELIMITER = ":"; - protected String delimiter = DEFAULT_DELIMITER; - - public CoordinatesDistanceTransform(String newColumnName, String firstColumn, String secondColumn, - String stdevColumn) { - this(newColumnName, firstColumn, secondColumn, stdevColumn, DEFAULT_DELIMITER); - } - - public CoordinatesDistanceTransform(@JsonProperty("newColumnName") String newColumnName, - @JsonProperty("firstColumn") String firstColumn, @JsonProperty("secondColumn") String secondColumn, - @JsonProperty("stdevColumn") String stdevColumn, @JsonProperty("delimiter") String delimiter) { - super(newColumnName, MathOp.Add /* dummy op */, - stdevColumn != null ? new String[] {firstColumn, secondColumn, stdevColumn} - : new String[] {firstColumn, secondColumn}); - this.delimiter = delimiter; - } - - @Override - protected ColumnMetaData derivedColumnMetaData(String newColumnName, Schema inputSchema) { - return new DoubleMetaData(newColumnName); - } - - @Override - protected Writable doOp(Writable... input) { - String[] first = input[0].toString().split(delimiter); - String[] second = input[1].toString().split(delimiter); - String[] stdev = columns.length > 2 ? input[2].toString().split(delimiter) : null; - - double dist = 0; - for (int i = 0; i < first.length; i++) { - double d = Double.parseDouble(first[i]) - Double.parseDouble(second[i]); - double s = stdev != null ? Double.parseDouble(stdev[i]) : 1; - dist += (d * d) / (s * s); - } - return new DoubleWritable(Math.sqrt(dist)); - } - - @Override - public String toString() { - return "CoordinatesDistanceTransform(newColumnName=\"" + newColumnName + "\",columns=" - + Arrays.toString(columns) + ",delimiter=" + delimiter + ")"; - } - - /** - * Transform an object - * in to another object - * - * @param input the record to transform - * @return the transformed writable - */ - @Override - public Object map(Object input) { - List row = (List) input; - String[] first = row.get(0).toString().split(delimiter); - String[] second = row.get(1).toString().split(delimiter); - String[] stdev = columns.length > 2 ? row.get(2).toString().split(delimiter) : null; - - double dist = 0; - for (int i = 0; i < first.length; i++) { - double d = Double.parseDouble(first[i]) - Double.parseDouble(second[i]); - double s = stdev != null ? 
Double.parseDouble(stdev[i]) : 1; - dist += (d * d) / (s * s); - } - return Math.sqrt(dist); - } - - /** - * Transform a sequence - * - * @param sequence - */ - @Override - public Object mapSequence(Object sequence) { - List seq = (List) sequence; - List ret = new ArrayList<>(); - for (Object step : seq) - ret.add((Double) map(step)); - return ret; - } -} diff --git a/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/GeoIPFetcher.java b/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/GeoIPFetcher.java deleted file mode 100644 index 47399d661..000000000 --- a/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/GeoIPFetcher.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.api.transform.transform.geo; - -import org.apache.commons.io.FileUtils; -import org.nd4j.common.base.Preconditions; -import org.nd4j.common.util.ArchiveUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.File; -import java.io.IOException; -import java.net.URL; - -public class GeoIPFetcher { - protected static final Logger log = LoggerFactory.getLogger(GeoIPFetcher.class); - - /** Default directory for http://dev.maxmind.com/geoip/geoipupdate/ */ - public static final String GEOIP_DIR = "/usr/local/share/GeoIP/"; - public static final String GEOIP_DIR2 = System.getProperty("user.home") + "/.datavec-geoip"; - - public static final String CITY_DB = "GeoIP2-City.mmdb"; - public static final String CITY_LITE_DB = "GeoLite2-City.mmdb"; - - public static final String CITY_LITE_URL = - "http://geolite.maxmind.com/download/geoip/database/GeoLite2-City.mmdb.gz"; - - public static synchronized File fetchCityDB() throws IOException { - File cityFile = new File(GEOIP_DIR, CITY_DB); - if (cityFile.isFile()) { - return cityFile; - } - cityFile = new File(GEOIP_DIR, CITY_LITE_DB); - if (cityFile.isFile()) { - return cityFile; - } - cityFile = new File(GEOIP_DIR2, CITY_LITE_DB); - if (cityFile.isFile()) { - return cityFile; - } - - log.info("Downloading GeoLite2 City database..."); - File archive = new File(GEOIP_DIR2, CITY_LITE_DB + ".gz"); - File dir = new File(GEOIP_DIR2); - dir.mkdirs(); - FileUtils.copyURLToFile(new URL(CITY_LITE_URL), archive); - ArchiveUtils.unzipFileTo(archive.getAbsolutePath(), dir.getAbsolutePath()); - Preconditions.checkState(cityFile.isFile(), "Error extracting files: expected city file does not exist after extraction"); - - return cityFile; - } -} diff --git a/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToCoordinatesTransform.java 
b/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToCoordinatesTransform.java deleted file mode 100644 index 47c13f50e..000000000 --- a/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToCoordinatesTransform.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.api.transform.transform.geo; - -import org.datavec.api.transform.geo.LocationType; -import org.nd4j.shade.jackson.annotation.JsonProperty; - -import java.io.IOException; - -public class IPAddressToCoordinatesTransform extends IPAddressToLocationTransform { - - public IPAddressToCoordinatesTransform(@JsonProperty("columnName") String columnName) throws IOException { - this(columnName, DEFAULT_DELIMITER); - } - - public IPAddressToCoordinatesTransform(@JsonProperty("columnName") String columnName, - @JsonProperty("delimiter") String delimiter) throws IOException { - super(columnName, LocationType.COORDINATES, delimiter); - } - - @Override - public String toString() { - return "IPAddressToCoordinatesTransform"; - } -} diff --git a/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToLocationTransform.java b/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToLocationTransform.java deleted file mode 100644 index 807df1ffe..000000000 --- a/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToLocationTransform.java +++ /dev/null @@ -1,184 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.api.transform.transform.geo; - -import com.maxmind.geoip2.DatabaseReader; -import com.maxmind.geoip2.exception.GeoIp2Exception; -import com.maxmind.geoip2.model.CityResponse; -import com.maxmind.geoip2.record.Location; -import com.maxmind.geoip2.record.Subdivision; -import lombok.extern.slf4j.Slf4j; -import org.datavec.api.transform.geo.LocationType; -import org.datavec.api.transform.metadata.ColumnMetaData; -import org.datavec.api.transform.metadata.StringMetaData; -import org.datavec.api.transform.transform.BaseColumnTransform; -import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; -import org.nd4j.shade.jackson.annotation.JsonProperty; - -import java.io.File; -import java.io.IOException; -import java.io.ObjectInputStream; -import java.io.ObjectOutputStream; -import java.net.InetAddress; - -@Slf4j -public class IPAddressToLocationTransform extends BaseColumnTransform { - /** - * Name of the system property to use when configuring the GeoIP database file.
- * Most users don't need to set this - typically used for testing purposes.
- * Set with the full local path, like: "C:/datavec-geo/GeoIP2-City-Test.mmdb" - */ - public static final String GEOIP_FILE_PROPERTY = "org.datavec.geoip.file"; - - private static File database; - private static DatabaseReader reader; - - public final static String DEFAULT_DELIMITER = ":"; - protected String delimiter = DEFAULT_DELIMITER; - protected LocationType locationType; - - private static synchronized void init() throws IOException { - // A File object pointing to your GeoIP2 or GeoLite2 database: - // http://dev.maxmind.com/geoip/geoip2/geolite2/ - if (database == null) { - String s = System.getProperty(GEOIP_FILE_PROPERTY); - if(s != null && !s.isEmpty()){ - //Use user-specified GEOIP file - mainly for testing purposes - File f = new File(s); - if(f.exists() && f.isFile()){ - database = f; - } else { - log.warn("GeoIP file (system property {}) is set to \"{}\" but this is not a valid file, using default database", GEOIP_FILE_PROPERTY, s); - database = GeoIPFetcher.fetchCityDB(); - } - } else { - database = GeoIPFetcher.fetchCityDB(); - } - } - - // This creates the DatabaseReader object, which should be reused across lookups. - if (reader == null) { - reader = new DatabaseReader.Builder(database).build(); - } - } - - public IPAddressToLocationTransform(String columnName) throws IOException { - this(columnName, LocationType.CITY); - } - - public IPAddressToLocationTransform(String columnName, LocationType locationType) throws IOException { - this(columnName, locationType, DEFAULT_DELIMITER); - } - - public IPAddressToLocationTransform(@JsonProperty("columnName") String columnName, - @JsonProperty("delimiter") LocationType locationType, @JsonProperty("delimiter") String delimiter) - throws IOException { - super(columnName); - this.delimiter = delimiter; - this.locationType = locationType; - init(); - } - - @Override - public ColumnMetaData getNewColumnMetaData(String newName, ColumnMetaData oldColumnType) { - return new StringMetaData(newName); //Output after transform: String (Text) - } - - @Override - public Writable map(Writable columnWritable) { - try { - InetAddress ipAddress = InetAddress.getByName(columnWritable.toString()); - CityResponse response = reader.city(ipAddress); - String text = ""; - switch (locationType) { - case CITY: - text = response.getCity().getName(); - break; - case CITY_ID: - text = response.getCity().getGeoNameId().toString(); - break; - case CONTINENT: - text = response.getContinent().getName(); - break; - case CONTINENT_ID: - text = response.getContinent().getGeoNameId().toString(); - break; - case COUNTRY: - text = response.getCountry().getName(); - break; - case COUNTRY_ID: - text = response.getCountry().getGeoNameId().toString(); - break; - case COORDINATES: - Location location = response.getLocation(); - text = location.getLatitude() + delimiter + location.getLongitude(); - break; - case POSTAL_CODE: - text = response.getPostal().getCode(); - break; - case SUBDIVISIONS: - for (Subdivision s : response.getSubdivisions()) { - if (text.length() > 0) { - text += delimiter; - } - text += s.getName(); - } - break; - case SUBDIVISIONS_ID: - for (Subdivision s : response.getSubdivisions()) { - if (text.length() > 0) { - text += delimiter; - } - text += s.getGeoNameId().toString(); - } - break; - default: - assert false; - } - if(text == null) - text = ""; - return new Text(text); - } catch (GeoIp2Exception | IOException e) { - throw new RuntimeException(e); - } - } - - @Override - public String toString() { - return "IPAddressToLocationTransform"; - } - - 
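// Editorial usage sketch (not part of the original class; "ip" is a hypothetical column name): - // Transform t = new IPAddressToLocationTransform("ip", LocationType.COORDINATES, ":"); - // t.setInputSchema(new Schema.Builder().addColumnString("ip").build()); - // List<Writable> out = t.map(Collections.singletonList((Writable) new Text("81.2.69.160"))); - // For LocationType.COORDINATES the single output value is "latitude:longitude", as exercised - // in TestGeoTransforms below. - 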
//Custom serialization methods, because GeoIP2 doesn't allow DatabaseReader objects to be serialized :( - private void writeObject(ObjectOutputStream out) throws IOException { - out.defaultWriteObject(); - } - - private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException { - in.defaultReadObject(); - init(); - } - - @Override - public Object map(Object input) { - return null; - } -} diff --git a/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/AssertTestsExtendBaseClass.java b/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/AssertTestsExtendBaseClass.java deleted file mode 100644 index 9423e525b..000000000 --- a/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/AssertTestsExtendBaseClass.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ -package org.datavec.api.transform; - -import lombok.extern.slf4j.Slf4j; -import org.nd4j.common.tests.AbstractAssertTestsClass; -import org.nd4j.common.tests.BaseND4JTest; - -import java.util.*; - -@Slf4j -public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { - - @Override - protected Set> getExclusions() { - //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) - return new HashSet<>(); - } - - @Override - protected String getPackageName() { - return "org.datavec.api.transform"; - } - - @Override - protected Class getBaseClass() { - return BaseND4JTest.class; - } -} diff --git a/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/reduce/TestGeoReduction.java b/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/reduce/TestGeoReduction.java deleted file mode 100644 index e5422a46d..000000000 --- a/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/reduce/TestGeoReduction.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.api.transform.reduce; - -import org.datavec.api.transform.ColumnType; -import org.datavec.api.transform.ReduceOp; -import org.datavec.api.transform.ops.IAggregableReduceOp; -import org.datavec.api.transform.reduce.geo.CoordinatesReduction; -import org.datavec.api.transform.schema.Schema; -import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; -import org.junit.Test; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -import static org.junit.Assert.assertEquals; - -/** - * @author saudet - */ -public class TestGeoReduction { - - @Test - public void testCustomReductions() { - - List<List<Writable>> inputs = new ArrayList<>(); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("1#5"))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("2#6"))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("3#7"))); - inputs.add(Arrays.asList((Writable) new Text("someKey"), new Text("4#8"))); - - List<Writable> expected = Arrays.asList((Writable) new Text("someKey"), new Text("10.0#26.0")); - - Schema schema = new Schema.Builder().addColumnString("key").addColumnString("coord").build(); - - Reducer reducer = new Reducer.Builder(ReduceOp.Count).keyColumns("key") - .customReduction("coord", new CoordinatesReduction("coordSum", ReduceOp.Sum, "#")).build(); - - reducer.setInputSchema(schema); - - IAggregableReduceOp<List<Writable>, List<Writable>> aggregableReduceOp = reducer.aggregableReducer(); - for (List<Writable> l : inputs) - aggregableReduceOp.accept(l); - List<Writable> out = aggregableReduceOp.get(); - - assertEquals(2, out.size()); - assertEquals(expected, out); - - //Check schema: - String[] expNames = new String[] {"key", "coordSum"}; - ColumnType[] expTypes = new ColumnType[] {ColumnType.String, ColumnType.String}; - Schema outSchema = reducer.transform(schema); - - assertEquals(2, outSchema.numColumns()); - for (int i = 0; i < 2; i++) { - assertEquals(expNames[i], outSchema.getName(i)); - assertEquals(expTypes[i], outSchema.getType(i)); - } - } -} diff --git a/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/transform/TestGeoTransforms.java b/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/transform/TestGeoTransforms.java deleted file mode 100644 index 349e04cc1..000000000 --- a/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/transform/TestGeoTransforms.java +++ /dev/null @@ -1,153 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.api.transform.transform; - -import org.datavec.api.transform.ColumnType; -import org.datavec.api.transform.Transform; -import org.datavec.api.transform.geo.LocationType; -import org.datavec.api.transform.schema.Schema; -import org.datavec.api.transform.transform.geo.CoordinatesDistanceTransform; -import org.datavec.api.transform.transform.geo.IPAddressToCoordinatesTransform; -import org.datavec.api.transform.transform.geo.IPAddressToLocationTransform; -import org.datavec.api.writable.DoubleWritable; -import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; -import org.nd4j.common.io.ClassPathResource; - -import java.io.*; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - -import static org.junit.Assert.assertEquals; - -/** - * @author saudet - */ -public class TestGeoTransforms { - - @BeforeClass - public static void beforeClass() throws Exception { - //Use test resources version to avoid tests suddenly failing due to IP/Location DB content changing - File f = new ClassPathResource("datavec-geo/GeoIP2-City-Test.mmdb").getFile(); - System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, f.getPath()); - } - - @AfterClass - public static void afterClass(){ - System.setProperty(IPAddressToLocationTransform.GEOIP_FILE_PROPERTY, ""); - } - - @Test - public void testCoordinatesDistanceTransform() throws Exception { - Schema schema = new Schema.Builder().addColumnString("point").addColumnString("mean").addColumnString("stddev") - .build(); - - Transform transform = new CoordinatesDistanceTransform("dist", "point", "mean", "stddev", "\\|"); - transform.setInputSchema(schema); - - Schema out = transform.transform(schema); - assertEquals(4, out.numColumns()); - assertEquals(Arrays.asList("point", "mean", "stddev", "dist"), out.getColumnNames()); - assertEquals(Arrays.asList(ColumnType.String, ColumnType.String, ColumnType.String, ColumnType.Double), - out.getColumnTypes()); - - assertEquals(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10"), new DoubleWritable(5.0)), - transform.map(Arrays.asList((Writable) new Text("-30"), new Text("20"), new Text("10")))); - assertEquals(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), new Text("10|5"), - new DoubleWritable(Math.sqrt(160))), - transform.map(Arrays.asList((Writable) new Text("50|40"), new Text("10|-20"), - new Text("10|5")))); - } - - @Test - public void testIPAddressToCoordinatesTransform() throws Exception { - Schema schema = new Schema.Builder().addColumnString("column").build(); - - Transform transform = new IPAddressToCoordinatesTransform("column", "CUSTOM_DELIMITER"); - transform.setInputSchema(schema); - - Schema out = transform.transform(schema); - - assertEquals(1, out.getColumnMetaData().size()); - assertEquals(ColumnType.String, out.getMetaData(0).getColumnType()); - - String in = "81.2.69.160"; - double latitude = 51.5142; - double longitude = -0.0931; - - List writables = transform.map(Collections.singletonList((Writable) new Text(in))); - assertEquals(1, writables.size()); - String[] coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER"); - assertEquals(2, coordinates.length); - assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1); - assertEquals(longitude, 
Double.parseDouble(coordinates[1]), 0.1); - - //Check serialization: things like DatabaseReader etc aren't serializable, hence we need custom serialization :/ - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ObjectOutputStream oos = new ObjectOutputStream(baos); - oos.writeObject(transform); - - byte[] bytes = baos.toByteArray(); - - ByteArrayInputStream bais = new ByteArrayInputStream(bytes); - ObjectInputStream ois = new ObjectInputStream(bais); - - Transform deserialized = (Transform) ois.readObject(); - writables = deserialized.map(Collections.singletonList((Writable) new Text(in))); - assertEquals(1, writables.size()); - coordinates = writables.get(0).toString().split("CUSTOM_DELIMITER"); - //System.out.println(Arrays.toString(coordinates)); - assertEquals(2, coordinates.length); - assertEquals(latitude, Double.parseDouble(coordinates[0]), 0.1); - assertEquals(longitude, Double.parseDouble(coordinates[1]), 0.1); - } - - @Test - public void testIPAddressToLocationTransform() throws Exception { - Schema schema = new Schema.Builder().addColumnString("column").build(); - LocationType[] locationTypes = LocationType.values(); - String in = "81.2.69.160"; - String[] locations = {"London", "2643743", "Europe", "6255148", "United Kingdom", "2635167", - "51.5142:-0.0931", "", "England", "6269131"}; //Note: no postcode in this test DB for this record - - for (int i = 0; i < locationTypes.length; i++) { - LocationType locationType = locationTypes[i]; - String location = locations[i]; - - Transform transform = new IPAddressToLocationTransform("column", locationType); - transform.setInputSchema(schema); - - Schema out = transform.transform(schema); - - assertEquals(1, out.getColumnMetaData().size()); - assertEquals(ColumnType.String, out.getMetaData(0).getColumnType()); - - List writables = transform.map(Collections.singletonList((Writable) new Text(in))); - assertEquals(1, writables.size()); - assertEquals(location, writables.get(0).toString()); - //System.out.println(location); - } - } -} diff --git a/datavec/datavec-data/pom.xml b/datavec/datavec-data/pom.xml index 233b85f9a..d5bfd6d05 100644 --- a/datavec/datavec-data/pom.xml +++ b/datavec/datavec-data/pom.xml @@ -37,11 +37,7 @@ datavec-data - datavec-data-audio - datavec-data-codec datavec-data-image - datavec-data-nlp - datavec-geo diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml deleted file mode 100644 index c69e1abcb..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml +++ /dev/null @@ -1,64 +0,0 @@ - - - - - - 4.0.0 - - - org.datavec - datavec-spark-inference-parent - 1.0.0-SNAPSHOT - - - datavec-spark-inference-client - - datavec-spark-inference-client - - - - org.datavec - datavec-spark-inference-server_2.11 - 1.0.0-SNAPSHOT - test - - - org.datavec - datavec-spark-inference-model - ${project.parent.version} - - - com.mashape.unirest - unirest-java - - - - - - test-nd4j-native - - - test-nd4j-cuda-11.0 - - - diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/inference/client/DataVecTransformClient.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/inference/client/DataVecTransformClient.java deleted file mode 100644 index 8a346b096..000000000 --- 
a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/inference/client/DataVecTransformClient.java +++ /dev/null @@ -1,292 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.client; - - -import com.mashape.unirest.http.ObjectMapper; -import com.mashape.unirest.http.Unirest; -import com.mashape.unirest.http.exceptions.UnirestException; -import lombok.AllArgsConstructor; -import lombok.extern.slf4j.Slf4j; -import org.datavec.api.transform.TransformProcess; -import org.datavec.image.transform.ImageTransformProcess; -import org.datavec.spark.inference.model.model.*; -import org.datavec.spark.inference.model.service.DataVecTransformService; -import org.nd4j.shade.jackson.core.JsonProcessingException; - -import java.io.IOException; - -@AllArgsConstructor -@Slf4j -public class DataVecTransformClient implements DataVecTransformService { - private String url; - - static { - // Only one time - Unirest.setObjectMapper(new ObjectMapper() { - private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = - new org.nd4j.shade.jackson.databind.ObjectMapper(); - - public T readValue(String value, Class valueType) { - try { - return jacksonObjectMapper.readValue(value, valueType); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - public String writeValue(Object value) { - try { - return jacksonObjectMapper.writeValueAsString(value); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - }); - } - - /** - * @param transformProcess - */ - @Override - public void setCSVTransformProcess(TransformProcess transformProcess) { - try { - String s = transformProcess.toJson(); - Unirest.post(url + "/transformprocess").header("accept", "application/json") - .header("Content-Type", "application/json").body(s).asJson(); - - } catch (UnirestException e) { - log.error("Error in setCSVTransformProcess()", e); - } - } - - @Override - public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) { - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - /** - * @return - */ - @Override - public TransformProcess getCSVTransformProcess() { - try { - String s = Unirest.get(url + "/transformprocess").header("accept", "application/json") - .header("Content-Type", "application/json").asString().getBody(); - return TransformProcess.fromJson(s); - } catch (UnirestException e) { - log.error("Error in getCSVTransformProcess()",e); - } - - return null; - } - - @Override - public ImageTransformProcess getImageTransformProcess() { - throw new UnsupportedOperationException("Invalid operation for " + 
this.getClass()); - } - - /** - * @param transform - * @return - */ - @Override - public SingleCSVRecord transformIncremental(SingleCSVRecord transform) { - try { - SingleCSVRecord singleCsvRecord = Unirest.post(url + "/transformincremental") - .header("accept", "application/json") - .header("Content-Type", "application/json") - .body(transform).asObject(SingleCSVRecord.class).getBody(); - return singleCsvRecord; - } catch (UnirestException e) { - log.error("Error in transformIncremental(SingleCSVRecord)",e); - } - return null; - } - - - /** - * @param batchCSVRecord - * @return - */ - @Override - public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) { - try { - SequenceBatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform").header("accept", "application/json") - .header("Content-Type", "application/json") - .header(SEQUENCE_OR_NOT_HEADER,"TRUE") - .body(batchCSVRecord) - .asObject(SequenceBatchCSVRecord.class) - .getBody(); - return batchCSVRecord1; - } catch (UnirestException e) { - log.error("",e); - } - - return null; - } - /** - * @param batchCSVRecord - * @return - */ - @Override - public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) { - try { - BatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform").header("accept", "application/json") - .header("Content-Type", "application/json") - .header(SEQUENCE_OR_NOT_HEADER,"FALSE") - .body(batchCSVRecord) - .asObject(BatchCSVRecord.class) - .getBody(); - return batchCSVRecord1; - } catch (UnirestException e) { - log.error("Error in transform(BatchCSVRecord)", e); - } - - return null; - } - - /** - * @param batchCSVRecord - * @return - */ - @Override - public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) { - try { - Base64NDArrayBody batchArray1 = Unirest.post(url + "/transformarray").header("accept", "application/json") - .header("Content-Type", "application/json").body(batchCSVRecord) - .asObject(Base64NDArrayBody.class).getBody(); - return batchArray1; - } catch (UnirestException e) { - log.error("Error in transformArray(BatchCSVRecord)",e); - } - - return null; - } - - /** - * @param singleCsvRecord - * @return - */ - @Override - public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) { - try { - Base64NDArrayBody array = Unirest.post(url + "/transformincrementalarray") - .header("accept", "application/json").header("Content-Type", "application/json") - .body(singleCsvRecord).asObject(Base64NDArrayBody.class).getBody(); - return array; - } catch (UnirestException e) { - log.error("Error in transformArrayIncremental(SingleCSVRecord)",e); - } - - return null; - } - - @Override - public Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException { - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - @Override - public Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException { - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - /** - * @param singleCsvRecord - * @return - */ - @Override - public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) { - try { - Base64NDArrayBody array = Unirest.post(url + "/transformincrementalarray") - .header("accept", "application/json") - .header("Content-Type", "application/json") - .header(SEQUENCE_OR_NOT_HEADER,"true") - .body(singleCsvRecord).asObject(Base64NDArrayBody.class).getBody(); - return array; - } catch 
(UnirestException e) { - log.error("Error in transformSequenceArrayIncremental",e); - } - - return null; - } - - /** - * @param batchCSVRecord - * @return - */ - @Override - public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) { - try { - Base64NDArrayBody batchArray1 = Unirest.post(url + "/transformarray").header("accept", "application/json") - .header("Content-Type", "application/json") - .header(SEQUENCE_OR_NOT_HEADER,"true") - .body(batchCSVRecord) - .asObject(Base64NDArrayBody.class).getBody(); - return batchArray1; - } catch (UnirestException e) { - log.error("Error in transformSequenceArray",e); - } - - return null; - } - - /** - * @param batchCSVRecord - * @return - */ - @Override - public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) { - try { - SequenceBatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform") - .header("accept", "application/json") - .header("Content-Type", "application/json") - .header(SEQUENCE_OR_NOT_HEADER,"true") - .body(batchCSVRecord) - .asObject(SequenceBatchCSVRecord.class).getBody(); - return batchCSVRecord1; - } catch (UnirestException e) { - log.error("Error in transformSequence",e); - } - - return null; - } - - /** - * @param transform - * @return - */ - @Override - public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) { - try { - SequenceBatchCSVRecord singleCsvRecord = Unirest.post(url + "/transformincremental") - .header("accept", "application/json") - .header("Content-Type", "application/json") - .header(SEQUENCE_OR_NOT_HEADER,"true") - .body(transform).asObject(SequenceBatchCSVRecord.class).getBody(); - return singleCsvRecord; - } catch (UnirestException e) { - log.error("Error in transformSequenceIncremental",e); - } - return null; - } -}
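For reference, a minimal usage sketch of the client removed above, assuming a transform server is already running (the endpoint URL and the class name TransformClientSketch are illustrative; the schema and process mirror DataVecTransformClientTest further down in this diff):

import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.spark.inference.client.DataVecTransformClient;
import org.datavec.spark.inference.model.model.SingleCSVRecord;

public class TransformClientSketch {
    public static void main(String[] args) {
        // Two double columns, mirroring the schema used by the tests below
        Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
        TransformProcess tp = new TransformProcess.Builder(schema)
                .convertToDouble("1.0").convertToDouble("2.0").build();

        // Assumes a CSVSparkTransformServer is listening on this (illustrative) port
        DataVecTransformClient client = new DataVecTransformClient("http://localhost:9600");
        client.setCSVTransformProcess(tp); // POSTs the process JSON to /transformprocess

        // POSTs one record to /transformincremental; the client returns null on failure
        SingleCSVRecord transformed = client.transformIncremental(new SingleCSVRecord("0", "0"));
        if (transformed != null) {
            System.out.println(transformed.getValues());
        }
    }
}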
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/AssertTestsExtendBaseClass.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/AssertTestsExtendBaseClass.java deleted file mode 100644 index de2970b27..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/AssertTestsExtendBaseClass.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ -package org.datavec.transform.client; - -import lombok.extern.slf4j.Slf4j; -import org.nd4j.common.tests.AbstractAssertTestsClass; -import org.nd4j.common.tests.BaseND4JTest; -import java.util.*; - -@Slf4j -public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { - - @Override - protected Set<Class<?>> getExclusions() { - //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) - return new HashSet<>(); - } - - @Override - protected String getPackageName() { - return "org.datavec.transform.client"; - } - - @Override - protected Class<?> getBaseClass() { - return BaseND4JTest.class; - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/DataVecTransformClientTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/DataVecTransformClientTest.java deleted file mode 100644 index 6619ec443..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/DataVecTransformClientTest.java +++ /dev/null @@ -1,139 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.transform.client; - -import org.apache.commons.io.FileUtils; -import org.datavec.api.transform.TransformProcess; -import org.datavec.api.transform.schema.Schema; -import org.datavec.spark.inference.server.CSVSparkTransformServer; -import org.datavec.spark.inference.client.DataVecTransformClient; -import org.datavec.spark.inference.model.model.Base64NDArrayBody; -import org.datavec.spark.inference.model.model.BatchCSVRecord; -import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord; -import org.datavec.spark.inference.model.model.SingleCSVRecord; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.serde.base64.Nd4jBase64; - -import java.io.File; -import java.io.IOException; -import java.net.ServerSocket; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.UUID; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assume.assumeNotNull; - -public class DataVecTransformClientTest { - private static CSVSparkTransformServer server; - private static int port = getAvailablePort(); - private static DataVecTransformClient client; - private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); - private static TransformProcess transformProcess = - new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build(); - private static File fileSave = new File(UUID.randomUUID().toString() + ".json"); - - @BeforeClass - public static void beforeClass() throws Exception { - FileUtils.write(fileSave, transformProcess.toJson()); - fileSave.deleteOnExit(); - server = new CSVSparkTransformServer(); - server.runMain(new String[] {"-dp", String.valueOf(port)}); - - client = new DataVecTransformClient("http://localhost:" + port); - client.setCSVTransformProcess(transformProcess); - } - - @AfterClass - public static void afterClass() throws Exception { - server.stop(); - } - - - @Test - public void testSequenceClient() { - SequenceBatchCSVRecord sequenceBatchCSVRecord = new SequenceBatchCSVRecord(); - SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"}); - - BatchCSVRecord batchCSVRecord = new BatchCSVRecord(Arrays.asList(singleCsvRecord, singleCsvRecord)); - List batchCSVRecordList = new ArrayList<>(); - for(int i = 0; i < 5; i++) { - batchCSVRecordList.add(batchCSVRecord); - } - - sequenceBatchCSVRecord.add(batchCSVRecordList); - - SequenceBatchCSVRecord sequenceBatchCSVRecord1 = client.transformSequence(sequenceBatchCSVRecord); - assumeNotNull(sequenceBatchCSVRecord1); - - Base64NDArrayBody array = client.transformSequenceArray(sequenceBatchCSVRecord); - assumeNotNull(array); - - Base64NDArrayBody incrementalBody = client.transformSequenceArrayIncremental(batchCSVRecord); - assumeNotNull(incrementalBody); - - Base64NDArrayBody incrementalSequenceBody = client.transformSequenceArrayIncremental(batchCSVRecord); - assumeNotNull(incrementalSequenceBody); - } - - @Test - public void testRecord() throws Exception { - SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"}); - SingleCSVRecord transformed = client.transformIncremental(singleCsvRecord); - assertEquals(singleCsvRecord.getValues().size(), transformed.getValues().size()); - Base64NDArrayBody body = 
client.transformArrayIncremental(singleCsvRecord); - INDArray arr = Nd4jBase64.fromBase64(body.getNdarray()); - assumeNotNull(arr); - } - - @Test - public void testBatchRecord() throws Exception { - SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"}); - - BatchCSVRecord batchCSVRecord = new BatchCSVRecord(Arrays.asList(singleCsvRecord, singleCsvRecord)); - BatchCSVRecord batchCSVRecord1 = client.transform(batchCSVRecord); - assertEquals(batchCSVRecord.getRecords().size(), batchCSVRecord1.getRecords().size()); - - Base64NDArrayBody body = client.transformArray(batchCSVRecord); - INDArray arr = Nd4jBase64.fromBase64(body.getNdarray()); - assumeNotNull(arr); - } - - - - public static int getAvailablePort() { - try { - ServerSocket socket = new ServerSocket(0); - try { - return socket.getLocalPort(); - } finally { - socket.close(); - } - } catch (IOException e) { - throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e); - } - } - -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/resources/application.conf b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/resources/application.conf deleted file mode 100644 index dbac92d83..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/resources/application.conf +++ /dev/null @@ -1,6 +0,0 @@ -play.modules.enabled += com.lightbend.lagom.discovery.zookeeper.ZooKeeperServiceLocatorModule -play.modules.enabled += io.skymind.skil.service.PredictionModule -play.crypto.secret = as8dufasdfuasdfjkasdkfalksjfk -play.server.pidfile.path=/tmp/RUNNING_PID - -play.server.http.port = 9600 diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml deleted file mode 100644 index fe9ca985a..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml +++ /dev/null @@ -1,63 +0,0 @@ - - - - - - 4.0.0 - - - org.datavec - datavec-spark-inference-parent - 1.0.0-SNAPSHOT - - - datavec-spark-inference-model - - datavec-spark-inference-model - - - - org.datavec - datavec-api - ${datavec.version} - - - org.datavec - datavec-data-image - - - org.datavec - datavec-local - ${project.version} - - - - - - test-nd4j-native - - - test-nd4j-cuda-11.0 - - - diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/CSVSparkTransform.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/CSVSparkTransform.java deleted file mode 100644 index e081708e0..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/CSVSparkTransform.java +++ /dev/null @@ -1,286 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. 
- * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.model; - -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.FieldVector; -import org.datavec.api.transform.TransformProcess; -import org.datavec.api.util.ndarray.RecordConverter; -import org.datavec.api.writable.Writable; -import org.datavec.arrow.ArrowConverter; -import org.datavec.arrow.recordreader.ArrowWritableRecordBatch; -import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch; -import org.datavec.local.transforms.LocalTransformExecutor; -import org.datavec.spark.inference.model.model.Base64NDArrayBody; -import org.datavec.spark.inference.model.model.BatchCSVRecord; -import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord; -import org.datavec.spark.inference.model.model.SingleCSVRecord; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.serde.base64.Nd4jBase64; - -import java.io.IOException; -import java.util.Arrays; -import java.util.List; - -import static org.datavec.arrow.ArrowConverter.*; -import static org.datavec.local.transforms.LocalTransformExecutor.execute; -import static org.datavec.local.transforms.LocalTransformExecutor.executeToSequence; - -@AllArgsConstructor -@Slf4j -public class CSVSparkTransform { - @Getter - private TransformProcess transformProcess; - private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); - - /** - * Convert a raw record via - * the {@link TransformProcess} - * to a base 64ed ndarray - * @param batch the record to convert - * @return the base 64ed ndarray - * @throws IOException - */ - public Base64NDArrayBody toArray(BatchCSVRecord batch) throws IOException { - List<List<Writable>> converted = execute(toArrowWritables(toArrowColumnsString( - bufferAllocator,transformProcess.getInitialSchema(), - batch.getRecordsAsString()), - transformProcess.getInitialSchema()),transformProcess); - - ArrowWritableRecordBatch arrowRecordBatch = (ArrowWritableRecordBatch) converted; - INDArray convert = ArrowConverter.toArray(arrowRecordBatch); - return new Base64NDArrayBody(Nd4jBase64.base64String(convert)); - } - - /** - * Convert a raw record via - * the {@link TransformProcess} - * to a base 64ed ndarray - * @param record the record to convert - * @return the base 64ed ndarray - * @throws IOException - */ - public Base64NDArrayBody toArray(SingleCSVRecord record) throws IOException { - List<Writable> record2 = toArrowWritablesSingle( - toArrowColumnsStringSingle(bufferAllocator, - transformProcess.getInitialSchema(),record.getValues()), - transformProcess.getInitialSchema()); - List<Writable> finalRecord = execute(Arrays.asList(record2),transformProcess).get(0); - INDArray convert = RecordConverter.toArray(DataType.DOUBLE, finalRecord); - return new Base64NDArrayBody(Nd4jBase64.base64String(convert)); - } - - /** - * Runs the transform process - * @param batch the record to transform - * 
@return the transformed record - */ - public BatchCSVRecord transform(BatchCSVRecord batch) { - BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); - List<List<Writable>> converted = execute(toArrowWritables(toArrowColumnsString( - bufferAllocator,transformProcess.getInitialSchema(), - batch.getRecordsAsString()), - transformProcess.getInitialSchema()),transformProcess); - int numCols = converted.get(0).size(); - for (int row = 0; row < converted.size(); row++) { - String[] values = new String[numCols]; - for (int i = 0; i < values.length; i++) - values[i] = converted.get(row).get(i).toString(); - batchCSVRecord.add(new SingleCSVRecord(values)); - } - - return batchCSVRecord; - - } - - /** - * Runs the transform process - * @param record the record to transform - * @return the transformed record - */ - public SingleCSVRecord transform(SingleCSVRecord record) { - List<Writable> record2 = toArrowWritablesSingle( - toArrowColumnsStringSingle(bufferAllocator, - transformProcess.getInitialSchema(),record.getValues()), - transformProcess.getInitialSchema()); - List<Writable> finalRecord = execute(Arrays.asList(record2),transformProcess).get(0); - String[] values = new String[finalRecord.size()]; - for (int i = 0; i < values.length; i++) - values[i] = finalRecord.get(i).toString(); - return new SingleCSVRecord(values); - - } - - /** - * - * @param transform - * @return - */ - public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) { - /** - * Sequence schema? - */ - List<List<List<Writable>>> converted = executeToSequence( - toArrowWritables(toArrowColumnsStringTimeSeries( - bufferAllocator, transformProcess.getInitialSchema(), - Arrays.asList(transform.getRecordsAsString())), - transformProcess.getInitialSchema()), transformProcess); - - SequenceBatchCSVRecord batchCSVRecord = new SequenceBatchCSVRecord(); - for (int i = 0; i < converted.size(); i++) { - BatchCSVRecord batchCSVRecord1 = BatchCSVRecord.fromWritables(converted.get(i)); - batchCSVRecord.add(Arrays.asList(batchCSVRecord1)); - } - - return batchCSVRecord; - } - - /** - * - * @param batchCSVRecordSequence - * @return - */ - public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecordSequence) { - List<List<List<String>>> recordsAsString = batchCSVRecordSequence.getRecordsAsString(); - boolean allSameLength = true; - Integer length = null; - for(List<List<String>> record : recordsAsString) { - if(length == null) { - length = record.size(); - } - else if(record.size() != length) { - allSameLength = false; - } - } - - if(allSameLength) { - List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), recordsAsString); - ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors, - transformProcess.getInitialSchema(), - recordsAsString.get(0).get(0).size()); - val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch,transformProcess); - return SequenceBatchCSVRecord.fromWritables(transformed); - } - - else { - val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecordSequence.getRecordsAsString(),transformProcess.getInitialSchema()),transformProcess); - return SequenceBatchCSVRecord.fromWritables(transformed); - - } - } - - /** - * TODO: optimize - * @param batchCSVRecordSequence - * @return - */ - public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecordSequence) { - List<List<List<String>>> strings = 
batchCSVRecordSequence.getRecordsAsString(); - boolean allSameLength = true; - Integer length = null; - for(List<List<String>> record : strings) { - if(length == null) { - length = record.size(); - } - else if(record.size() != length) { - allSameLength = false; - } - } - - if(allSameLength) { - List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), strings); - ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors,transformProcess.getInitialSchema(),strings.get(0).get(0).size()); - val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch,transformProcess); - INDArray arr = RecordConverter.toTensor(transformed).reshape(strings.size(),strings.get(0).get(0).size(),strings.get(0).size()); - try { - return new Base64NDArrayBody(Nd4jBase64.base64String(arr)); - } catch (IOException e) { - throw new IllegalStateException(e); - } - } - - else { - val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecordSequence.getRecordsAsString(),transformProcess.getInitialSchema()),transformProcess); - INDArray arr = RecordConverter.toTensor(transformed).reshape(strings.size(),strings.get(0).get(0).size(),strings.get(0).size()); - try { - return new Base64NDArrayBody(Nd4jBase64.base64String(arr)); - } catch (IOException e) { - throw new IllegalStateException(e); - } - } - - } - - /** - * - * @param singleCsvRecord - * @return - */ - public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) { - List<List<List<Writable>>> converted = executeToSequence(toArrowWritables(toArrowColumnsString( - bufferAllocator,transformProcess.getInitialSchema(), - singleCsvRecord.getRecordsAsString()), - transformProcess.getInitialSchema()),transformProcess); - ArrowWritableRecordTimeSeriesBatch arrowWritableRecordBatch = (ArrowWritableRecordTimeSeriesBatch) converted; - INDArray arr = RecordConverter.toTensor(arrowWritableRecordBatch); - try { - return new Base64NDArrayBody(Nd4jBase64.base64String(arr)); - } catch (IOException e) { - log.error("",e); - } - - return null; - } - - public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) { - List<List<List<String>>> strings = batchCSVRecord.getRecordsAsString(); - boolean allSameLength = true; - Integer length = null; - for(List<List<String>> record : strings) { - if(length == null) { - length = record.size(); - } - else if(record.size() != length) { - allSameLength = false; - } - } - - if(allSameLength) { - List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), strings); - ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors,transformProcess.getInitialSchema(),strings.get(0).get(0).size()); - val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch,transformProcess); - return SequenceBatchCSVRecord.fromWritables(transformed); - } - - else { - val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecord.getRecordsAsString(),transformProcess.getInitialSchema()),transformProcess); - return SequenceBatchCSVRecord.fromWritables(transformed); - - } - - } -}
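A short sketch of driving CSVSparkTransform (deleted above) directly, without the REST layer; it mirrors testTransformer in CSVSparkTransformTest near the end of this diff, so the schema and the convert-to-string process are taken from that test:

import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.spark.inference.model.CSVSparkTransform;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;

public class CSVSparkTransformSketch {
    public static void main(String[] args) throws Exception {
        Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
        TransformProcess tp = new TransformProcess.Builder(schema)
                .convertToString("1.0").convertToString("2.0").build();

        // Arrow-backed local execution of the transform process
        CSVSparkTransform transform = new CSVSparkTransform(tp);
        SingleCSVRecord out = transform.transform(new SingleCSVRecord("1.0", "2.0"));

        // The same record as a base64-encoded NDArray, decoded back for inspection
        Base64NDArrayBody body = transform.toArray(new SingleCSVRecord("1.0", "2.0"));
        INDArray arr = Nd4jBase64.fromBase64(body.getNdarray());
        System.out.println(out.getValues() + " -> " + arr);
    }
}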
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/ImageSparkTransform.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/ImageSparkTransform.java deleted file mode 100644 index a004c439b..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/ImageSparkTransform.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.model; - -import lombok.AllArgsConstructor; -import lombok.Getter; -import org.datavec.image.data.ImageWritable; -import org.datavec.image.transform.ImageTransformProcess; -import org.datavec.spark.inference.model.model.Base64NDArrayBody; -import org.datavec.spark.inference.model.model.BatchImageRecord; -import org.datavec.spark.inference.model.model.SingleImageRecord; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.serde.base64.Nd4jBase64; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -@AllArgsConstructor -public class ImageSparkTransform { - @Getter - private ImageTransformProcess imageTransformProcess; - - public Base64NDArrayBody toArray(SingleImageRecord record) throws IOException { - ImageWritable record2 = imageTransformProcess.transformFileUriToInput(record.getUri()); - INDArray finalRecord = imageTransformProcess.executeArray(record2); - - return new Base64NDArrayBody(Nd4jBase64.base64String(finalRecord)); - } - - public Base64NDArrayBody toArray(BatchImageRecord batch) throws IOException { - List<INDArray> records = new ArrayList<>(); - - for (SingleImageRecord imgRecord : batch.getRecords()) { - ImageWritable record2 = imageTransformProcess.transformFileUriToInput(imgRecord.getUri()); - INDArray finalRecord = imageTransformProcess.executeArray(record2); - records.add(finalRecord); - } - - INDArray array = Nd4j.concat(0, records.toArray(new INDArray[records.size()])); - - return new Base64NDArrayBody(Nd4jBase64.base64String(array)); - } - -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/Base64NDArrayBody.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/Base64NDArrayBody.java deleted file mode 100644 index 0d6c680ad..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/Base64NDArrayBody.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying 
materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.model.model; - -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; - -@Data -@AllArgsConstructor -@NoArgsConstructor -public class Base64NDArrayBody { - private String ndarray; -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/BatchCSVRecord.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/BatchCSVRecord.java deleted file mode 100644 index 82ecedc51..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/BatchCSVRecord.java +++ /dev/null @@ -1,104 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.model.model; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; -import org.datavec.api.writable.Writable; -import org.nd4j.linalg.dataset.DataSet; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; - -@Data -@AllArgsConstructor -@Builder -@NoArgsConstructor -public class BatchCSVRecord implements Serializable { - private List records; - - - /** - * Get the records as a list of strings - * (basically the underlying values for - * {@link SingleCSVRecord}) - * @return - */ - public List> getRecordsAsString() { - if(records == null) - records = new ArrayList<>(); - List> ret = new ArrayList<>(); - for(SingleCSVRecord csvRecord : records) { - ret.add(csvRecord.getValues()); - } - return ret; - } - - - /** - * Create a batch csv record - * from a list of writables. 
- * @param batch - * @return - */ - public static BatchCSVRecord fromWritables(List> batch) { - List records = new ArrayList<>(batch.size()); - for(List list : batch) { - List add = new ArrayList<>(list.size()); - for(Writable writable : list) { - add.add(writable.toString()); - } - records.add(new SingleCSVRecord(add)); - } - - return BatchCSVRecord.builder().records(records).build(); - } - - - /** - * Add a record - * @param record - */ - public void add(SingleCSVRecord record) { - if (records == null) - records = new ArrayList<>(); - records.add(record); - } - - - /** - * Return a batch record based on a dataset - * @param dataSet the dataset to get the batch record for - * @return the batch record - */ - public static BatchCSVRecord fromDataSet(DataSet dataSet) { - BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); - for (int i = 0; i < dataSet.numExamples(); i++) { - batchCSVRecord.add(SingleCSVRecord.fromRow(dataSet.get(i))); - } - - return batchCSVRecord; - } - -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/BatchImageRecord.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/BatchImageRecord.java deleted file mode 100644 index ff101c659..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/BatchImageRecord.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.model.model; - -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.net.URI; -import java.util.ArrayList; -import java.util.List; - -@Data -@AllArgsConstructor -@NoArgsConstructor -public class BatchImageRecord { - private List records; - - /** - * Add a record - * @param record - */ - public void add(SingleImageRecord record) { - if (records == null) - records = new ArrayList<>(); - records.add(record); - } - - public void add(URI uri) { - this.add(new SingleImageRecord(uri)); - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SequenceBatchCSVRecord.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SequenceBatchCSVRecord.java deleted file mode 100644 index eed4fac59..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SequenceBatchCSVRecord.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.model.model; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; -import org.datavec.api.writable.Writable; -import org.nd4j.linalg.dataset.DataSet; -import org.nd4j.linalg.dataset.MultiDataSet; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; - -@Data -@AllArgsConstructor -@Builder -@NoArgsConstructor -public class SequenceBatchCSVRecord implements Serializable { - private List<List<BatchCSVRecord>> records; - - /** - * Add a record - * @param record - */ - public void add(List<BatchCSVRecord> record) { - if (records == null) - records = new ArrayList<>(); - records.add(record); - } - - /** - * Get the records as a list of strings directly - * (this basically "unpacks" the objects) - * @return - */ - public List<List<List<String>>> getRecordsAsString() { - if(records == null) - return Collections.emptyList(); - List<List<List<String>>> ret = new ArrayList<>(records.size()); - for(List<BatchCSVRecord> record : records) { - List<List<String>> add = new ArrayList<>(); - for(BatchCSVRecord batchCSVRecord : record) { - for (SingleCSVRecord singleCSVRecord : batchCSVRecord.getRecords()) { - add.add(singleCSVRecord.getValues()); - } - } - - ret.add(add); - } - - return ret; - } - - /** - * Convert a writables time series to a sequence batch - * @param input - * @return - */ - public static SequenceBatchCSVRecord fromWritables(List<List<List<Writable>>> input) { - SequenceBatchCSVRecord ret = new SequenceBatchCSVRecord(); - for(int i = 0; i < input.size(); i++) { - ret.add(Arrays.asList(BatchCSVRecord.fromWritables(input.get(i)))); - } - - return ret; - } - - - /** - * Return a batch record based on a dataset - * @param dataSet the dataset to get the batch record for - * @return the batch record - */ - public static SequenceBatchCSVRecord fromDataSet(MultiDataSet dataSet) { - SequenceBatchCSVRecord batchCSVRecord = new SequenceBatchCSVRecord(); - for (int i = 0; i < dataSet.numFeatureArrays(); i++) { - batchCSVRecord.add(Arrays.asList(BatchCSVRecord.fromDataSet(new DataSet(dataSet.getFeatures(i),dataSet.getLabels(i))))); - } - - return batchCSVRecord; - } - -}
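The nesting in SequenceBatchCSVRecord (mini batch -> sequence -> time step) is easiest to see in a small sketch; the values here are arbitrary:

import java.util.Arrays;

import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord;
import org.datavec.spark.inference.model.model.SingleCSVRecord;

public class SequenceBatchSketch {
    public static void main(String[] args) {
        // One time step per SingleCSVRecord; one sequence per BatchCSVRecord
        BatchCSVRecord sequence = new BatchCSVRecord(Arrays.asList(
                new SingleCSVRecord("1.0", "2.0"),
                new SingleCSVRecord("3.0", "4.0")));

        // The outer list gathers sequences into a mini batch
        SequenceBatchCSVRecord sequenceBatch = new SequenceBatchCSVRecord();
        sequenceBatch.add(Arrays.asList(sequence));

        // Unpacks to List<List<List<String>>>: [sequence][time step][column]
        System.out.println(sequenceBatch.getRecordsAsString());
        // prints [[[1.0, 2.0], [3.0, 4.0]]]
    }
}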
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SingleCSVRecord.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SingleCSVRecord.java deleted file mode 100644 index 575a91918..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SingleCSVRecord.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.model.model; - -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; -import org.nd4j.linalg.dataset.DataSet; - -import java.io.Serializable; -import java.util.Arrays; -import java.util.List; - -@Data -@AllArgsConstructor -@NoArgsConstructor -public class SingleCSVRecord implements Serializable { - private List<String> values; - - /** - * Create from an array of values (uses a list internally) - * @param values - */ - public SingleCSVRecord(String...values) { - this.values = Arrays.asList(values); - } - - /** - * Instantiate a csv record from a vector. - * For classification (a one hot label matrix), the index of the - * maximum label value is appended to the end of the record; for - * regression, all label values are appended. - * @param row the input vectors - * @return the record from this {@link DataSet} - */ - public static SingleCSVRecord fromRow(DataSet row) { - if (!row.getFeatures().isVector() && !row.getFeatures().isScalar()) - throw new IllegalArgumentException("Passed in dataset must represent a scalar or vector"); - if (!row.getLabels().isVector() && !row.getLabels().isScalar()) - throw new IllegalArgumentException("Passed in dataset labels must be a scalar or vector"); - //classification - SingleCSVRecord record; - int idx = 0; - if (row.getLabels().sumNumber().doubleValue() == 1.0) { - String[] values = new String[row.getFeatures().columns() + 1]; - for (int i = 0; i < row.getFeatures().length(); i++) { - values[idx++] = String.valueOf(row.getFeatures().getDouble(i)); - } - int maxIdx = 0; - for (int i = 0; i < row.getLabels().length(); i++) { - if (row.getLabels().getDouble(maxIdx) < row.getLabels().getDouble(i)) { - maxIdx = i; - } - } - - values[idx++] = String.valueOf(maxIdx); - record = new SingleCSVRecord(values); - } - //regression (any number of values) - else { - String[] values = new String[row.getFeatures().columns() + row.getLabels().columns()]; - for (int i = 0; i < row.getFeatures().length(); i++) { - values[idx++] = String.valueOf(row.getFeatures().getDouble(i)); - } - for (int i = 0; i < row.getLabels().length(); i++) { - values[idx++] = String.valueOf(row.getLabels().getDouble(i)); - } - - - record = new SingleCSVRecord(values); - - } - return record; - } - -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SingleImageRecord.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SingleImageRecord.java deleted file mode 100644 index 9fe3df042..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SingleImageRecord.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. 
- * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.model.model; - -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.net.URI; - -@Data -@AllArgsConstructor -@NoArgsConstructor -public class SingleImageRecord { - private URI uri; -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/service/DataVecTransformService.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/service/DataVecTransformService.java deleted file mode 100644 index c23dd562c..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/service/DataVecTransformService.java +++ /dev/null @@ -1,131 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.model.service; - -import org.datavec.api.transform.TransformProcess; -import org.datavec.image.transform.ImageTransformProcess; -import org.datavec.spark.inference.model.model.*; - -import java.io.IOException; - -public interface DataVecTransformService { - - String SEQUENCE_OR_NOT_HEADER = "Sequence"; - - - /** - * - * @param transformProcess - */ - void setCSVTransformProcess(TransformProcess transformProcess); - - /** - * - * @param imageTransformProcess - */ - void setImageTransformProcess(ImageTransformProcess imageTransformProcess); - - /** - * - * @return - */ - TransformProcess getCSVTransformProcess(); - - /** - * - * @return - */ - ImageTransformProcess getImageTransformProcess(); - - /** - * - * @param singleCsvRecord - * @return - */ - SingleCSVRecord transformIncremental(SingleCSVRecord singleCsvRecord); - - SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord); - - /** - * - * @param batchCSVRecord - * @return - */ - BatchCSVRecord transform(BatchCSVRecord batchCSVRecord); - - /** - * - * @param batchCSVRecord - * @return - */ - Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord); - - /** - * - * @param singleCsvRecord - * @return - */ - Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord); - - /** - * - * @param singleImageRecord - * @return - * @throws IOException - */ - Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException; - - /** - * - * @param batchImageRecord - * @return - * @throws IOException - */ - Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException; - - /** - * - * @param singleCsvRecord - * @return - */ - Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord); - - /** - * - * @param batchCSVRecord - * @return - */ - Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord); - - /** - * - * @param batchCSVRecord - * @return - */ - SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord); - - /** - * - * @param transform - * @return - */ - SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform); -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java deleted file mode 100644 index ab76b206e..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ -package org.datavec.spark.transform; - -import lombok.extern.slf4j.Slf4j; -import org.nd4j.common.tests.AbstractAssertTestsClass; -import org.nd4j.common.tests.BaseND4JTest; - -import java.util.*; - -@Slf4j -public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { - - @Override - protected Set> getExclusions() { - //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) - return new HashSet<>(); - } - - @Override - protected String getPackageName() { - return "org.datavec.spark.transform"; - } - - @Override - protected Class getBaseClass() { - return BaseND4JTest.class; - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/BatchCSVRecordTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/BatchCSVRecordTest.java deleted file mode 100644 index a5ce6c474..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/BatchCSVRecordTest.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.transform; - -import org.datavec.spark.inference.model.model.BatchCSVRecord; -import org.junit.Test; -import org.nd4j.linalg.dataset.DataSet; -import org.nd4j.linalg.factory.Nd4j; - -import static org.junit.Assert.assertEquals; - -public class BatchCSVRecordTest { - - @Test - public void testBatchRecordCreationFromDataSet() { - DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{1, 1}, {1, 1}})); - - BatchCSVRecord batchCSVRecord = BatchCSVRecord.fromDataSet(dataSet); - assertEquals(2, batchCSVRecord.getRecords().size()); - } - -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/CSVSparkTransformTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/CSVSparkTransformTest.java deleted file mode 100644 index 7d1fe5f3b..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/CSVSparkTransformTest.java +++ /dev/null @@ -1,212 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.transform; - -import org.datavec.api.transform.TransformProcess; -import org.datavec.api.transform.schema.Schema; -import org.datavec.api.transform.transform.integer.BaseIntegerTransform; -import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform; -import org.datavec.api.writable.DoubleWritable; -import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; -import org.datavec.spark.inference.model.CSVSparkTransform; -import org.datavec.spark.inference.model.model.Base64NDArrayBody; -import org.datavec.spark.inference.model.model.BatchCSVRecord; -import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord; -import org.datavec.spark.inference.model.model.SingleCSVRecord; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.serde.base64.Nd4jBase64; - -import java.util.*; - -import static org.junit.Assert.*; - -public class CSVSparkTransformTest { - @Test - public void testTransformer() throws Exception { - List<Writable> input = new ArrayList<>(); - input.add(new DoubleWritable(1.0)); - input.add(new DoubleWritable(2.0)); - - Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); - List<Writable> output = new ArrayList<>(); - output.add(new Text("1.0")); - output.add(new Text("2.0")); - - TransformProcess transformProcess = - new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build(); - CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess); - String[] values = new String[] {"1.0", "2.0"}; - SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values)); - Base64NDArrayBody body = csvSparkTransform.toArray(new SingleCSVRecord(values)); - INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); - assertTrue(fromBase64.isVector()); -// System.out.println("Base 64ed array " + fromBase64); - } - - @Test - public void testTransformerBatch() throws Exception { - List<Writable> input = new ArrayList<>(); - input.add(new DoubleWritable(1.0)); - input.add(new DoubleWritable(2.0)); - - Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); - List<Writable> output = new ArrayList<>(); - output.add(new Text("1.0")); - output.add(new Text("2.0")); - - TransformProcess transformProcess = - new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build(); - CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess); - String[] values = new String[] {"1.0", "2.0"}; - SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values)); - BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); - for (int i = 0; i < 3; i++) - batchCSVRecord.add(record); - //data type is string, unable to convert - BatchCSVRecord batchCSVRecord1 = csvSparkTransform.transform(batchCSVRecord); - /* Base64NDArrayBody body = csvSparkTransform.toArray(batchCSVRecord1); - INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); - assertTrue(fromBase64.isMatrix()); - System.out.println("Base 64ed array " + fromBase64); */ - } - - - - @Test - public void testSingleBatchSequence() throws Exception { - List<Writable> input = new ArrayList<>(); - input.add(new DoubleWritable(1.0)); - input.add(new DoubleWritable(2.0)); - - Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); - List<Writable> output =
new ArrayList<>(); - output.add(new Text("1.0")); - output.add(new Text("2.0")); - - TransformProcess transformProcess = - new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build(); - CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess); - String[] values = new String[] {"1.0", "2.0"}; - SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values)); - BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); - for (int i = 0; i < 3; i++) - batchCSVRecord.add(record); - BatchCSVRecord batchCSVRecord1 = csvSparkTransform.transform(batchCSVRecord); - SequenceBatchCSVRecord sequenceBatchCSVRecord = new SequenceBatchCSVRecord(); - sequenceBatchCSVRecord.add(Arrays.asList(batchCSVRecord)); - Base64NDArrayBody sequenceArray = csvSparkTransform.transformSequenceArray(sequenceBatchCSVRecord); - INDArray outputBody = Nd4jBase64.fromBase64(sequenceArray.getNdarray()); - - - //ensure accumulation - sequenceBatchCSVRecord.add(Arrays.asList(batchCSVRecord)); - sequenceArray = csvSparkTransform.transformSequenceArray(sequenceBatchCSVRecord); - assertArrayEquals(new long[]{2,2,3},Nd4jBase64.fromBase64(sequenceArray.getNdarray()).shape()); - - SequenceBatchCSVRecord transformed = csvSparkTransform.transformSequence(sequenceBatchCSVRecord); - assertNotNull(transformed.getRecords()); -// System.out.println(transformed); - - - } - - @Test - public void testSpecificSequence() throws Exception { - final Schema schema = new Schema.Builder() - .addColumnsString("action") - .build(); - - final TransformProcess transformProcess = new TransformProcess.Builder(schema) - .removeAllColumnsExceptFor("action") - .transform(new ConverToLowercase("action")) - .convertToSequence() - .transform(new TextToCharacterIndexTransform("action", "action_sequence", - defaultCharIndex(), false)) - .integerToOneHot("action_sequence",0,29) - .build(); - - final String[] data1 = new String[] { "test1" }; - final String[] data2 = new String[] { "test2" }; - final BatchCSVRecord batchCsvRecord = new BatchCSVRecord( - Arrays.asList( - new SingleCSVRecord(data1), - new SingleCSVRecord(data2))); - - final CSVSparkTransform transform = new CSVSparkTransform(transformProcess); -// System.out.println(transform.transformSequenceIncremental(batchCsvRecord)); - transform.transformSequenceIncremental(batchCsvRecord); - assertEquals(3,Nd4jBase64.fromBase64(transform.transformSequenceArrayIncremental(batchCsvRecord).getNdarray()).rank()); - - } - - private static Map<Character, Integer> defaultCharIndex() { - Map<Character, Integer> ret = new TreeMap<>(); - - ret.put('a',0); - ret.put('b',1); - ret.put('c',2); - ret.put('d',3); - ret.put('e',4); - ret.put('f',5); - ret.put('g',6); - ret.put('h',7); - ret.put('i',8); - ret.put('j',9); - ret.put('k',10); - ret.put('l',11); - ret.put('m',12); - ret.put('n',13); - ret.put('o',14); - ret.put('p',15); - ret.put('q',16); - ret.put('r',17); - ret.put('s',18); - ret.put('t',19); - ret.put('u',20); - ret.put('v',21); - ret.put('w',22); - ret.put('x',23); - ret.put('y',24); - ret.put('z',25); - ret.put('/',26); - ret.put(' ',27); - ret.put('(',28); - ret.put(')',29); - - return ret; - } - - public static class ConverToLowercase extends BaseIntegerTransform { - public ConverToLowercase(String column) { - super(column); - } - - public Text map(Writable writable) { - return new Text(writable.toString().toLowerCase()); - } - - public Object map(Object input) { - return new Text(input.toString().toLowerCase()); - } - } -}
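For reference, the core flow exercised by the deleted CSVSparkTransformTest above can be condensed into a standalone sketch. This is assembled from the calls in testTransformer(); the column names and values mirror the test and are illustrative only:

import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.spark.inference.model.CSVSparkTransform;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;

public class CSVSparkTransformSketch {
    public static void main(String[] args) throws Exception {
        // Input schema: two double columns (the names are arbitrary labels here)
        Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
        // Transform process over that schema: convert both columns to strings
        TransformProcess transformProcess =
                new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build();
        CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess);

        // Transform a single CSV row
        String[] values = new String[] {"1.0", "2.0"};
        SingleCSVRecord transformed = csvSparkTransform.transform(new SingleCSVRecord(values));

        // Or encode the row as a base64 NDArray body and decode it back
        Base64NDArrayBody body = csvSparkTransform.toArray(new SingleCSVRecord(values));
        INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
        System.out.println(transformed.getValues() + " -> " + fromBase64);
    }
}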
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/ImageSparkTransformTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/ImageSparkTransformTest.java deleted file mode 100644 index 415730b18..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/ImageSparkTransformTest.java +++ /dev/null @@ -1,86 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.transform; - -import org.datavec.image.transform.ImageTransformProcess; -import org.datavec.spark.inference.model.ImageSparkTransform; -import org.datavec.spark.inference.model.model.Base64NDArrayBody; -import org.datavec.spark.inference.model.model.BatchImageRecord; -import org.datavec.spark.inference.model.model.SingleImageRecord; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.common.io.ClassPathResource; -import org.nd4j.serde.base64.Nd4jBase64; - -import java.io.File; - -import static org.junit.Assert.assertEquals; - -public class ImageSparkTransformTest { - - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - - @Test - public void testSingleImageSparkTransform() throws Exception { - int seed = 12345; - - File f1 = new ClassPathResource("datavec-spark-inference/testimages/class1/A.jpg").getFile(); - - SingleImageRecord imgRecord = new SingleImageRecord(f1.toURI()); - - ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(seed) - .scaleImageTransform(10).cropImageTransform(5).build(); - - ImageSparkTransform imgSparkTransform = new ImageSparkTransform(imgTransformProcess); - Base64NDArrayBody body = imgSparkTransform.toArray(imgRecord); - - INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); -// System.out.println("Base 64ed array " + fromBase64); - assertEquals(1, fromBase64.size(0)); - } - - @Test - public void testBatchImageSparkTransform() throws Exception { - int seed = 12345; - - File f0 = new ClassPathResource("datavec-spark-inference/testimages/class1/A.jpg").getFile(); - File f1 = new ClassPathResource("datavec-spark-inference/testimages/class1/B.png").getFile(); - File f2 = new ClassPathResource("datavec-spark-inference/testimages/class1/C.jpg").getFile(); - - BatchImageRecord batch = new BatchImageRecord(); - batch.add(f0.toURI()); - batch.add(f1.toURI()); - batch.add(f2.toURI()); - - ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(seed) -
.scaleImageTransform(10).cropImageTransform(5).build(); - - ImageSparkTransform imgSparkTransform = new ImageSparkTransform(imgTransformProcess); - Base64NDArrayBody body = imgSparkTransform.toArray(batch); - - INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); -// System.out.println("Base 64ed array " + fromBase64); - assertEquals(3, fromBase64.size(0)); - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/SingleCSVRecordTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/SingleCSVRecordTest.java deleted file mode 100644 index 599f8eead..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/SingleCSVRecordTest.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.transform; - -import org.datavec.spark.inference.model.model.SingleCSVRecord; -import org.junit.Test; -import org.nd4j.linalg.dataset.DataSet; -import org.nd4j.linalg.factory.Nd4j; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; - -public class SingleCSVRecordTest { - - @Test(expected = IllegalArgumentException.class) - public void testVectorAssertion() { - DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(1, 1)); - SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet); - fail(singleCsvRecord.toString() + " should have thrown an exception"); - } - - @Test - public void testVectorOneHotLabel() { - DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{0, 1}, {1, 0}})); - - //assert - SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet.get(0)); - assertEquals(3, singleCsvRecord.getValues().size()); - - } - - @Test - public void testVectorRegression() { - DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{1, 1}, {1, 1}})); - - //assert - SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet.get(0)); - assertEquals(4, singleCsvRecord.getValues().size()); - - } - -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/SingleImageRecordTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/SingleImageRecordTest.java deleted file mode 100644 index 3c321e583..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/SingleImageRecordTest.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * 
****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.transform; - -import org.datavec.spark.inference.model.model.SingleImageRecord; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.nd4j.common.io.ClassPathResource; - -import java.io.File; - -public class SingleImageRecordTest { - - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - - @Test - public void testImageRecord() throws Exception { - File f = testDir.newFolder(); - new ClassPathResource("datavec-spark-inference/testimages/").copyDirectory(f); - File f0 = new File(f, "class0/0.jpg"); - File f1 = new File(f, "/class1/A.jpg"); - - SingleImageRecord imgRecord = new SingleImageRecord(f0.toURI()); - - // need jackson test? - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml deleted file mode 100644 index 8a65942db..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml +++ /dev/null @@ -1,154 +0,0 @@ - - - - - - 4.0.0 - - - org.datavec - datavec-spark-inference-parent - 1.0.0-SNAPSHOT - - - datavec-spark-inference-server_2.11 - - datavec-spark-inference-server - - - - 2.11.12 - 2.11 - 1.8 - 1.8 - - - - - org.datavec - datavec-spark-inference-model - ${datavec.version} - - - org.datavec - datavec-spark_2.11 - ${project.version} - - - org.datavec - datavec-data-image - - - joda-time - joda-time - - - org.apache.commons - commons-lang3 - - - org.hibernate - hibernate-validator - ${hibernate.version} - - - org.scala-lang - scala-library - ${scala.version} - - - org.scala-lang - scala-reflect - ${scala.version} - - - com.typesafe.play - play-java_2.11 - ${playframework.version} - - - com.google.code.findbugs - jsr305 - - - net.jodah - typetools - - - - - net.jodah - typetools - ${jodah.typetools.version} - - - com.typesafe.play - play-json_2.11 - ${playframework.version} - - - com.typesafe.play - play-server_2.11 - ${playframework.version} - - - com.typesafe.play - play_2.11 - ${playframework.version} - - - com.typesafe.play - play-netty-server_2.11 - ${playframework.version} - - - com.typesafe.akka - akka-cluster_2.11 - 2.5.23 - - - com.mashape.unirest - unirest-java - test - - - com.beust - jcommander - ${jcommander.version} - - - org.apache.spark - spark-core_2.11 - ${spark.version} - - - - - - test-nd4j-native - - - test-nd4j-cuda-11.0 - - - diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/CSVSparkTransformServer.java 
b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/CSVSparkTransformServer.java deleted file mode 100644 index 9ef085515..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/CSVSparkTransformServer.java +++ /dev/null @@ -1,352 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.server; - -import com.beust.jcommander.JCommander; -import com.beust.jcommander.ParameterException; -import lombok.Data; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.io.FileUtils; -import org.datavec.api.transform.TransformProcess; -import org.datavec.image.transform.ImageTransformProcess; -import org.datavec.spark.inference.model.CSVSparkTransform; -import org.datavec.spark.inference.model.model.*; -import play.BuiltInComponents; -import play.Mode; -import play.routing.Router; -import play.routing.RoutingDsl; -import play.server.Server; - -import java.io.File; -import java.io.IOException; -import java.util.Base64; -import java.util.Random; - -import static play.mvc.Results.*; - -@Slf4j -@Data -public class CSVSparkTransformServer extends SparkTransformServer { - private CSVSparkTransform transform; - - public void runMain(String[] args) throws Exception { - JCommander jcmdr = new JCommander(this); - - try { - jcmdr.parse(args); - } catch (ParameterException e) { - //User provides invalid input -> print the usage info - jcmdr.usage(); - if (jsonPath == null) - System.err.println("Json path parameter is missing."); - try { - Thread.sleep(500); - } catch (Exception e2) { - } - System.exit(1); - } - - if (jsonPath != null) { - String json = FileUtils.readFileToString(new File(jsonPath)); - TransformProcess transformProcess = TransformProcess.fromJson(json); - transform = new CSVSparkTransform(transformProcess); - } else { - log.warn("Server started with no json for transform process. 
Please ensure you specify a transform process via sending a post request with raw json" - + "to /transformprocess"); - } - - //Set play secret key, if required - //http://www.playframework.com/documentation/latest/ApplicationSecret - String crypto = System.getProperty("play.crypto.secret"); - if (crypto == null || "changeme".equals(crypto) || "".equals(crypto) ) { - byte[] newCrypto = new byte[1024]; - - new Random().nextBytes(newCrypto); - - String base64 = Base64.getEncoder().encodeToString(newCrypto); - System.setProperty("play.crypto.secret", base64); - } - - - server = Server.forRouter(Mode.PROD, port, this::createRouter); - } - - protected Router createRouter(BuiltInComponents b){ - RoutingDsl routingDsl = RoutingDsl.fromComponents(b); - - routingDsl.GET("/transformprocess").routingTo(req -> { - try { - if (transform == null) - return badRequest(); - return ok(transform.getTransformProcess().toJson()).as(contentType); - } catch (Exception e) { - log.error("Error in GET /transformprocess",e); - return internalServerError(e.getMessage()); - } - }); - - routingDsl.POST("/transformprocess").routingTo(req -> { - try { - TransformProcess transformProcess = TransformProcess.fromJson(getJsonText(req)); - setCSVTransformProcess(transformProcess); - log.info("Transform process initialized"); - return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType); - } catch (Exception e) { - log.error("Error in POST /transformprocess",e); - return internalServerError(e.getMessage()); - } - }); - - routingDsl.POST("/transformincremental").routingTo(req -> { - if (isSequence(req)) { - try { - BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); - if (record == null) - return badRequest(); - return ok(objectMapper.writeValueAsString(transformSequenceIncremental(record))).as(contentType); - } catch (Exception e) { - log.error("Error in /transformincremental", e); - return internalServerError(e.getMessage()); - } - } else { - try { - SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class); - if (record == null) - return badRequest(); - return ok(objectMapper.writeValueAsString(transformIncremental(record))).as(contentType); - } catch (Exception e) { - log.error("Error in /transformincremental", e); - return internalServerError(e.getMessage()); - } - } - }); - - routingDsl.POST("/transform").routingTo(req -> { - if (isSequence(req)) { - try { - SequenceBatchCSVRecord batch = transformSequence(objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class)); - if (batch == null) - return badRequest(); - return ok(objectMapper.writeValueAsString(batch)).as(contentType); - } catch (Exception e) { - log.error("Error in /transform", e); - return internalServerError(e.getMessage()); - } - } else { - try { - BatchCSVRecord input = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); - BatchCSVRecord batch = transform(input); - if (batch == null) - return badRequest(); - return ok(objectMapper.writeValueAsString(batch)).as(contentType); - } catch (Exception e) { - log.error("Error in /transform", e); - return internalServerError(e.getMessage()); - } - } - }); - - routingDsl.POST("/transformincrementalarray").routingTo(req -> { - if (isSequence(req)) { - try { - BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); - if (record == null) - return badRequest(); - return ok(objectMapper.writeValueAsString(transformSequenceArrayIncremental(record))).as(contentType); - } catch (Exception 
e) { - log.error("Error in /transformincrementalarray", e); - return internalServerError(e.getMessage()); - } - } else { - try { - SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class); - if (record == null) - return badRequest(); - return ok(objectMapper.writeValueAsString(transformArrayIncremental(record))).as(contentType); - } catch (Exception e) { - log.error("Error in /transformincrementalarray", e); - return internalServerError(e.getMessage()); - } - } - }); - - routingDsl.POST("/transformarray").routingTo(req -> { - if (isSequence(req)) { - try { - SequenceBatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class); - if (batchCSVRecord == null) - return badRequest(); - return ok(objectMapper.writeValueAsString(transformSequenceArray(batchCSVRecord))).as(contentType); - } catch (Exception e) { - log.error("Error in /transformarray", e); - return internalServerError(e.getMessage()); - } - } else { - try { - BatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); - if (batchCSVRecord == null) - return badRequest(); - return ok(objectMapper.writeValueAsString(transformArray(batchCSVRecord))).as(contentType); - } catch (Exception e) { - log.error("Error in /transformarray", e); - return internalServerError(e.getMessage()); - } - } - }); - - return routingDsl.build(); - } - - public static void main(String[] args) throws Exception { - new CSVSparkTransformServer().runMain(args); - } - - /** - * @param transformProcess - */ - @Override - public void setCSVTransformProcess(TransformProcess transformProcess) { - this.transform = new CSVSparkTransform(transformProcess); - } - - @Override - public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) { - log.error("Unsupported operation: setImageTransformProcess not supported for class", getClass()); - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - /** - * @return - */ - @Override - public TransformProcess getCSVTransformProcess() { - return transform.getTransformProcess(); - } - - @Override - public ImageTransformProcess getImageTransformProcess() { - log.error("Unsupported operation: getImageTransformProcess not supported for class", getClass()); - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - - /** - * - */ - /** - * @param transform - * @return - */ - @Override - public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) { - return this.transform.transformSequenceIncremental(transform); - } - - /** - * @param batchCSVRecord - * @return - */ - @Override - public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) { - return transform.transformSequence(batchCSVRecord); - } - - /** - * @param batchCSVRecord - * @return - */ - @Override - public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) { - return this.transform.transformSequenceArray(batchCSVRecord); - } - - /** - * @param singleCsvRecord - * @return - */ - @Override - public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) { - return this.transform.transformSequenceArrayIncremental(singleCsvRecord); - } - - /** - * @param transform - * @return - */ - @Override - public SingleCSVRecord transformIncremental(SingleCSVRecord transform) { - return this.transform.transform(transform); - } - - @Override - public SequenceBatchCSVRecord 
transform(SequenceBatchCSVRecord batchCSVRecord) { - return this.transform.transform(batchCSVRecord); - } - - /** - * @param batchCSVRecord - * @return - */ - @Override - public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) { - return transform.transform(batchCSVRecord); - } - - /** - * @param batchCSVRecord - * @return - */ - @Override - public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) { - try { - return this.transform.toArray(batchCSVRecord); - } catch (IOException e) { - log.error("Error in transformArray",e); - throw new IllegalStateException("Transform array shouldn't throw exception"); - } - } - - /** - * @param singleCsvRecord - * @return - */ - @Override - public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) { - try { - return this.transform.toArray(singleCsvRecord); - } catch (IOException e) { - log.error("Error in transformArrayIncremental",e); - throw new IllegalStateException("Transform array shouldn't throw exception"); - } - } - - @Override - public Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException { - log.error("Unsupported operation: transformIncrementalArray(SingleImageRecord) not supported for class", getClass()); - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - @Override - public Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException { - log.error("Unsupported operation: transformArray(BatchImageRecord) not supported for class", getClass()); - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } -}
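A minimal launch sketch for the server deleted above, assuming a TransformProcess serialized to JSON on disk (the path below is illustrative); the -j/--jsonPath and -dp/--dataVecPort flags are the @Parameter options declared on the SparkTransformServer base class later in this diff:

import org.datavec.spark.inference.server.CSVSparkTransformServer;

public class StartCsvTransformServer {
    public static void main(String[] args) throws Exception {
        CSVSparkTransformServer server = new CSVSparkTransformServer();
        // -j: path to a TransformProcess.toJson() file (illustrative); omit it to start
        // empty and POST the JSON to /transformprocess later, as the log.warn above notes.
        // -dp: HTTP port to listen on (SparkTransformServer defaults to 9000).
        server.runMain(new String[] {"-j", "/path/to/transformprocess.json", "-dp", "9050"});
    }
}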
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/ImageSparkTransformServer.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/ImageSparkTransformServer.java deleted file mode 100644 index e7744ecaa..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/ImageSparkTransformServer.java +++ /dev/null @@ -1,261 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.server; - -import com.beust.jcommander.JCommander; -import com.beust.jcommander.ParameterException; -import lombok.Data; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.io.FileUtils; -import org.datavec.api.transform.TransformProcess; -import org.datavec.image.transform.ImageTransformProcess; -import org.datavec.spark.inference.model.ImageSparkTransform; -import org.datavec.spark.inference.model.model.*; -import play.BuiltInComponents; -import play.Mode; -import play.libs.Files; -import play.mvc.Http; -import play.routing.Router; -import play.routing.RoutingDsl; -import play.server.Server; - -import java.io.File; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -import static play.mvc.Results.*; - -@Slf4j -@Data -public class ImageSparkTransformServer extends SparkTransformServer { - private ImageSparkTransform transform; - - public void runMain(String[] args) throws Exception { - JCommander jcmdr = new JCommander(this); - - try { - jcmdr.parse(args); - } catch (ParameterException e) { - //User provides invalid input -> print the usage info - jcmdr.usage(); - if (jsonPath == null) - System.err.println("Json path parameter is missing."); - try { - Thread.sleep(500); - } catch (Exception e2) { - } - System.exit(1); - } - - if (jsonPath != null) { - String json = FileUtils.readFileToString(new File(jsonPath)); - ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(json); - transform = new ImageSparkTransform(transformProcess); - } else { - log.warn("Server started with no json for transform process. Please ensure you specify a transform process via sending a post request with raw json" - + "to /transformprocess"); - } - - server = Server.forRouter(Mode.PROD, port, this::createRouter); - } - - protected Router createRouter(BuiltInComponents builtInComponents){ - RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents); - - routingDsl.GET("/transformprocess").routingTo(req -> { - try { - if (transform == null) - return badRequest(); - log.info("Transform process initialized"); - return ok(objectMapper.writeValueAsString(transform.getImageTransformProcess())).as(contentType); - } catch (Exception e) { - log.error("",e); - return internalServerError(); - } - }); - - routingDsl.POST("/transformprocess").routingTo(req -> { - try { - ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(getJsonText(req)); - setImageTransformProcess(transformProcess); - log.info("Transform process initialized"); - return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType); - } catch (Exception e) { - log.error("",e); - return internalServerError(); - } - }); - - routingDsl.POST("/transformincrementalarray").routingTo(req -> { - try { - SingleImageRecord record = objectMapper.readValue(getJsonText(req), SingleImageRecord.class); - if (record == null) - return badRequest(); - return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType); - } catch (Exception e) { - log.error("",e); - return internalServerError(); - } - }); - - routingDsl.POST("/transformincrementalimage").routingTo(req -> { - try { - Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData(); - List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles(); - if (files.isEmpty() || files.get(0).getRef() == null ) { - return badRequest(); - } - - File file =
files.get(0).getRef().path().toFile(); - SingleImageRecord record = new SingleImageRecord(file.toURI()); - - return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType); - } catch (Exception e) { - log.error("",e); - return internalServerError(); - } - }); - - routingDsl.POST("/transformarray").routingTo(req -> { - try { - BatchImageRecord batch = objectMapper.readValue(getJsonText(req), BatchImageRecord.class); - if (batch == null) - return badRequest(); - return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType); - } catch (Exception e) { - log.error("",e); - return internalServerError(); - } - }); - - routingDsl.POST("/transformimage").routingTo(req -> { - try { - Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData(); - List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles(); - if (files.size() == 0) { - return badRequest(); - } - - List<SingleImageRecord> records = new ArrayList<>(); - - for (Http.MultipartFormData.FilePart<Files.TemporaryFile> filePart : files) { - Files.TemporaryFile file = filePart.getRef(); - if (file != null) { - SingleImageRecord record = new SingleImageRecord(file.path().toUri()); - records.add(record); - } - } - - BatchImageRecord batch = new BatchImageRecord(records); - - return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType); - } catch (Exception e) { - log.error("",e); - return internalServerError(); - } - }); - - return routingDsl.build(); - } - - @Override - public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) { - throw new UnsupportedOperationException(); - } - - @Override - public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) { - throw new UnsupportedOperationException(); - - } - - @Override - public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) { - throw new UnsupportedOperationException(); - - } - - @Override - public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) { - throw new UnsupportedOperationException(); - - } - - @Override - public void setCSVTransformProcess(TransformProcess transformProcess) { - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - @Override - public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) { - this.transform = new ImageSparkTransform(imageTransformProcess); - } - - @Override - public TransformProcess getCSVTransformProcess() { - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - @Override - public ImageTransformProcess getImageTransformProcess() { - return transform.getImageTransformProcess(); - } - - @Override - public SingleCSVRecord transformIncremental(SingleCSVRecord singleCsvRecord) { - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - @Override - public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) { - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - @Override - public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) { - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - @Override - public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) { - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - @Override - public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) { - throw new
UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - @Override - public Base64NDArrayBody transformIncrementalArray(SingleImageRecord record) throws IOException { - return transform.toArray(record); - } - - @Override - public Base64NDArrayBody transformArray(BatchImageRecord batch) throws IOException { - return transform.toArray(batch); - } - - public static void main(String[] args) throws Exception { - new ImageSparkTransformServer().runMain(args); - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/SparkTransformServer.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/SparkTransformServer.java deleted file mode 100644 index c89ef90cc..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/SparkTransformServer.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.server; - -import com.beust.jcommander.Parameter; -import com.fasterxml.jackson.databind.JsonNode; -import org.datavec.spark.inference.model.model.Base64NDArrayBody; -import org.datavec.spark.inference.model.model.BatchCSVRecord; -import org.datavec.spark.inference.model.service.DataVecTransformService; -import org.nd4j.shade.jackson.databind.ObjectMapper; -import play.mvc.Http; -import play.server.Server; - -public abstract class SparkTransformServer implements DataVecTransformService { - @Parameter(names = {"-j", "--jsonPath"}, arity = 1) - protected String jsonPath = null; - @Parameter(names = {"-dp", "--dataVecPort"}, arity = 1) - protected int port = 9000; - @Parameter(names = {"-dt", "--dataType"}, arity = 1) - private TransformDataType transformDataType = null; - protected Server server; - protected static ObjectMapper objectMapper = new ObjectMapper(); - protected static String contentType = "application/json"; - - public abstract void runMain(String[] args) throws Exception; - - /** - * Stop the server - */ - public void stop() { - if (server != null) - server.stop(); - } - - protected boolean isSequence(Http.Request request) { - return request.hasHeader(SEQUENCE_OR_NOT_HEADER) - && request.header(SEQUENCE_OR_NOT_HEADER).get().equalsIgnoreCase("true"); - } - - protected String getJsonText(Http.Request request) { - JsonNode tryJson = request.body().asJson(); - if (tryJson != null) - return tryJson.toString(); - else - return request.body().asText(); - } - - public abstract Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord); -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/SparkTransformServerChooser.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/SparkTransformServerChooser.java deleted file mode 100644 index aa4945ddb..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/SparkTransformServerChooser.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.server; - -import lombok.Data; -import lombok.extern.slf4j.Slf4j; - -import java.io.InvalidClassException; -import java.util.Arrays; -import java.util.List; - -@Data -@Slf4j -public class SparkTransformServerChooser { - private SparkTransformServer sparkTransformServer = null; - private TransformDataType transformDataType = null; - - public void runMain(String[] args) throws Exception { - - int pos = getMatchingPosition(args, "-dt", "--dataType"); - if (pos == -1) { - log.error("no valid options"); - log.error("-dt, --dataType Options: [CSV, IMAGE]"); - throw new Exception("no valid options"); - } else { - transformDataType = TransformDataType.valueOf(args[pos + 1]); - } - - switch (transformDataType) { - case CSV: - sparkTransformServer = new CSVSparkTransformServer(); - break; - case IMAGE: - sparkTransformServer = new ImageSparkTransformServer(); - break; - default: - throw new InvalidClassException("no matching SparkTransform class"); - } - - sparkTransformServer.runMain(args); - } - - private int getMatchingPosition(String[] args, String... options) { - List<String> optionList = Arrays.asList(options); - - for (int i = 0; i < args.length; i++) { - if (optionList.contains(args[i])) { - return i; - } - } - return -1; - } - - - public static void main(String[] args) throws Exception { - new SparkTransformServerChooser().runMain(args); - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/TransformDataType.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/TransformDataType.java deleted file mode 100644 index 643cd5652..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/TransformDataType.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.server; - -public enum TransformDataType { - CSV, IMAGE, -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/resources/application.conf b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/resources/application.conf deleted file mode 100644 index 28a4aa208..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/resources/application.conf +++ /dev/null @@ -1,350 +0,0 @@ -# This is the main configuration file for the application.
-# https://www.playframework.com/documentation/latest/ConfigFile -# ~~~~~ -# Play uses HOCON as its configuration file format. HOCON has a number -# of advantages over other config formats, but there are two things that -# can be used when modifying settings. -# -# You can include other configuration files in this main application.conf file: -#include "extra-config.conf" -# -# You can declare variables and substitute for them: -#mykey = ${some.value} -# -# And if an environment variable exists when there is no other subsitution, then -# HOCON will fall back to substituting environment variable: -#mykey = ${JAVA_HOME} - -## Akka -# https://www.playframework.com/documentation/latest/ScalaAkka#Configuration -# https://www.playframework.com/documentation/latest/JavaAkka#Configuration -# ~~~~~ -# Play uses Akka internally and exposes Akka Streams and actors in Websockets and -# other streaming HTTP responses. -akka { - # "akka.log-config-on-start" is extraordinarly useful because it log the complete - # configuration at INFO level, including defaults and overrides, so it s worth - # putting at the very top. - # - # Put the following in your conf/logback.xml file: - # - # - # - # And then uncomment this line to debug the configuration. - # - #log-config-on-start = true -} - -## Modules -# https://www.playframework.com/documentation/latest/Modules -# ~~~~~ -# Control which modules are loaded when Play starts. Note that modules are -# the replacement for "GlobalSettings", which are deprecated in 2.5.x. -# Please see https://www.playframework.com/documentation/latest/GlobalSettings -# for more information. -# -# You can also extend Play functionality by using one of the publically available -# Play modules: https://playframework.com/documentation/latest/ModuleDirectory -play.modules { - # By default, Play will load any class called Module that is defined - # in the root package (the "app" directory), or you can define them - # explicitly below. - # If there are any built-in modules that you want to disable, you can list them here. - #enabled += my.application.Module - - # If there are any built-in modules that you want to disable, you can list them here. - #disabled += "" -} - -## Internationalisation -# https://www.playframework.com/documentation/latest/JavaI18N -# https://www.playframework.com/documentation/latest/ScalaI18N -# ~~~~~ -# Play comes with its own i18n settings, which allow the user's preferred language -# to map through to internal messages, or allow the language to be stored in a cookie. -play.i18n { - # The application languages - langs = [ "en" ] - - # Whether the language cookie should be secure or not - #langCookieSecure = true - - # Whether the HTTP only attribute of the cookie should be set to true - #langCookieHttpOnly = true -} - -## Play HTTP settings -# ~~~~~ -play.http { - ## Router - # https://www.playframework.com/documentation/latest/JavaRouting - # https://www.playframework.com/documentation/latest/ScalaRouting - # ~~~~~ - # Define the Router object to use for this application. - # This router will be looked up first when the application is starting up, - # so make sure this is the entry point. - # Furthermore, it's assumed your route file is named properly. - # So for an application router like `my.application.Router`, - # you may need to define a router file `conf/my.application.routes`. 
- # Default to Routes in the root package (aka "apps" folder) (and conf/routes) - #router = my.application.Router - - ## Action Creator - # https://www.playframework.com/documentation/latest/JavaActionCreator - # ~~~~~ - #actionCreator = null - - ## ErrorHandler - # https://www.playframework.com/documentation/latest/JavaRouting - # https://www.playframework.com/documentation/latest/ScalaRouting - # ~~~~~ - # If null, will attempt to load a class called ErrorHandler in the root package, - #errorHandler = null - - ## Filters - # https://www.playframework.com/documentation/latest/ScalaHttpFilters - # https://www.playframework.com/documentation/latest/JavaHttpFilters - # ~~~~~ - # Filters run code on every request. They can be used to perform - # common logic for all your actions, e.g. adding common headers. - # Defaults to "Filters" in the root package (aka "apps" folder) - # Alternatively you can explicitly register a class here. - #filters += my.application.Filters - - ## Session & Flash - # https://www.playframework.com/documentation/latest/JavaSessionFlash - # https://www.playframework.com/documentation/latest/ScalaSessionFlash - # ~~~~~ - session { - # Sets the cookie to be sent only over HTTPS. - #secure = true - - # Sets the cookie to be accessed only by the server. - #httpOnly = true - - # Sets the max-age field of the cookie to 5 minutes. - # NOTE: this only sets when the browser will discard the cookie. Play will consider any - # cookie value with a valid signature to be a valid session forever. To implement a server side session timeout, - # you need to put a timestamp in the session and check it at regular intervals to possibly expire it. - #maxAge = 300 - - # Sets the domain on the session cookie. - #domain = "example.com" - } - - flash { - # Sets the cookie to be sent only over HTTPS. - #secure = true - - # Sets the cookie to be accessed only by the server. - #httpOnly = true - } -} - -## Netty Provider -# https://www.playframework.com/documentation/latest/SettingsNetty -# ~~~~~ -play.server.netty { - # Whether the Netty wire should be logged - #log.wire = true - - # If you run Play on Linux, you can use Netty's native socket transport - # for higher performance with less garbage. - #transport = "native" -} - -## WS (HTTP Client) -# https://www.playframework.com/documentation/latest/ScalaWS#Configuring-WS -# ~~~~~ -# The HTTP client primarily used for REST APIs. The default client can be -# configured directly, but you can also create different client instances -# with customized settings. You must enable this by adding to build.sbt: -# -# libraryDependencies += ws // or javaWs if using java -# -play.ws { - # Sets HTTP requests not to follow 302 requests - #followRedirects = false - - # Sets the maximum number of open HTTP connections for the client. - #ahc.maxConnectionsTotal = 50 - - ## WS SSL - # https://www.playframework.com/documentation/latest/WsSSL - # ~~~~~ - ssl { - # Configuring HTTPS with Play WS does not require programming. You can - # set up both trustManager and keyManager for mutual authentication, and - # turn on JSSE debugging in development with a reload. - #debug.handshake = true - #trustManager = { - # stores = [ - # { type = "JKS", path = "exampletrust.jks" } - # ] - #} - } -} - -## Cache -# https://www.playframework.com/documentation/latest/JavaCache -# https://www.playframework.com/documentation/latest/ScalaCache -# ~~~~~ -# Play comes with an integrated cache API that can reduce the operational -# overhead of repeated requests. 
You must enable this by adding to build.sbt: -# -# libraryDependencies += cache -# -play.cache { - # If you want to bind several caches, you can bind the individually - #bindCaches = ["db-cache", "user-cache", "session-cache"] -} - -## Filters -# https://www.playframework.com/documentation/latest/Filters -# ~~~~~ -# There are a number of built-in filters that can be enabled and configured -# to give Play greater security. You must enable this by adding to build.sbt: -# -# libraryDependencies += filters -# -play.filters { - ## CORS filter configuration - # https://www.playframework.com/documentation/latest/CorsFilter - # ~~~~~ - # CORS is a protocol that allows web applications to make requests from the browser - # across different domains. - # NOTE: You MUST apply the CORS configuration before the CSRF filter, as CSRF has - # dependencies on CORS settings. - cors { - # Filter paths by a whitelist of path prefixes - #pathPrefixes = ["/some/path", ...] - - # The allowed origins. If null, all origins are allowed. - #allowedOrigins = ["http://www.example.com"] - - # The allowed HTTP methods. If null, all methods are allowed - #allowedHttpMethods = ["GET", "POST"] - } - - ## CSRF Filter - # https://www.playframework.com/documentation/latest/ScalaCsrf#Applying-a-global-CSRF-filter - # https://www.playframework.com/documentation/latest/JavaCsrf#Applying-a-global-CSRF-filter - # ~~~~~ - # Play supports multiple methods for verifying that a request is not a CSRF request. - # The primary mechanism is a CSRF token. This token gets placed either in the query string - # or body of every form submitted, and also gets placed in the users session. - # Play then verifies that both tokens are present and match. - csrf { - # Sets the cookie to be sent only over HTTPS - #cookie.secure = true - - # Defaults to CSRFErrorHandler in the root package. - #errorHandler = MyCSRFErrorHandler - } - - ## Security headers filter configuration - # https://www.playframework.com/documentation/latest/SecurityHeaders - # ~~~~~ - # Defines security headers that prevent XSS attacks. - # If enabled, then all options are set to the below configuration by default: - headers { - # The X-Frame-Options header. If null, the header is not set. - #frameOptions = "DENY" - - # The X-XSS-Protection header. If null, the header is not set. - #xssProtection = "1; mode=block" - - # The X-Content-Type-Options header. If null, the header is not set. - #contentTypeOptions = "nosniff" - - # The X-Permitted-Cross-Domain-Policies header. If null, the header is not set. - #permittedCrossDomainPolicies = "master-only" - - # The Content-Security-Policy header. If null, the header is not set. - #contentSecurityPolicy = "default-src 'self'" - } - - ## Allowed hosts filter configuration - # https://www.playframework.com/documentation/latest/AllowedHostsFilter - # ~~~~~ - # Play provides a filter that lets you configure which hosts can access your application. - # This is useful to prevent cache poisoning attacks. - hosts { - # Allow requests to example.com, its subdomains, and localhost:9000. - #allowed = [".example.com", "localhost:9000"] - } -} - -## Evolutions -# https://www.playframework.com/documentation/latest/Evolutions -# ~~~~~ -# Evolutions allows database scripts to be automatically run on startup in dev mode -# for database migrations. 
You must enable this by adding to build.sbt: -# -# libraryDependencies += evolutions -# -play.evolutions { - # You can disable evolutions for a specific datasource if necessary - #db.default.enabled = false -} - -## Database Connection Pool -# https://www.playframework.com/documentation/latest/SettingsJDBC -# ~~~~~ -# Play doesn't require a JDBC database to run, but you can easily enable one. -# -# libraryDependencies += jdbc -# -play.db { - # The combination of these two settings results in "db.default" as the - # default JDBC pool: - #config = "db" - #default = "default" - - # Play uses HikariCP as the default connection pool. You can override - # settings by changing the prototype: - prototype { - # Sets a fixed JDBC connection pool size of 50 - #hikaricp.minimumIdle = 50 - #hikaricp.maximumPoolSize = 50 - } -} - -## JDBC Datasource -# https://www.playframework.com/documentation/latest/JavaDatabase -# https://www.playframework.com/documentation/latest/ScalaDatabase -# ~~~~~ -# Once JDBC datasource is set up, you can work with several different -# database options: -# -# Slick (Scala preferred option): https://www.playframework.com/documentation/latest/PlaySlick -# JPA (Java preferred option): https://playframework.com/documentation/latest/JavaJPA -# EBean: https://playframework.com/documentation/latest/JavaEbean -# Anorm: https://www.playframework.com/documentation/latest/ScalaAnorm -# -db { - # You can declare as many datasources as you want. - # By convention, the default datasource is named `default` - - # https://www.playframework.com/documentation/latest/Developing-with-the-H2-Database - default.driver = org.h2.Driver - default.url = "jdbc:h2:mem:play" - #default.username = sa - #default.password = "" - - # You can expose this datasource via JNDI if needed (Useful for JPA) - default.jndiName=DefaultDS - - # You can turn on SQL logging for any datasource - # https://www.playframework.com/documentation/latest/Highlights25#Logging-SQL-statements - #default.logSql=true -} - -jpa.default=defaultPersistenceUnit - - -#Increase default maximum post length - used for remote listener functionality -#Can get response 413 with larger networks without setting this -# parsers.text.maxLength is deprecated, use play.http.parser.maxMemoryBuffer instead -#parsers.text.maxLength=10M -play.http.parser.maxMemoryBuffer=10M diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java deleted file mode 100644 index ab76b206e..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ -package org.datavec.spark.transform; - -import lombok.extern.slf4j.Slf4j; -import org.nd4j.common.tests.AbstractAssertTestsClass; -import org.nd4j.common.tests.BaseND4JTest; - -import java.util.*; - -@Slf4j -public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { - - @Override - protected Set> getExclusions() { - //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) - return new HashSet<>(); - } - - @Override - protected String getPackageName() { - return "org.datavec.spark.transform"; - } - - @Override - protected Class getBaseClass() { - return BaseND4JTest.class; - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java deleted file mode 100644 index 8f309caff..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java +++ /dev/null @@ -1,127 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.transform; - -import com.mashape.unirest.http.JsonNode; -import com.mashape.unirest.http.ObjectMapper; -import com.mashape.unirest.http.Unirest; -import org.apache.commons.io.FileUtils; -import org.datavec.api.transform.TransformProcess; -import org.datavec.api.transform.schema.Schema; -import org.datavec.spark.inference.server.CSVSparkTransformServer; -import org.datavec.spark.inference.model.model.Base64NDArrayBody; -import org.datavec.spark.inference.model.model.BatchCSVRecord; -import org.datavec.spark.inference.model.model.SingleCSVRecord; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; - -import java.io.File; -import java.io.IOException; -import java.util.UUID; - -import static org.junit.Assert.assertTrue; -import static org.junit.Assume.assumeNotNull; - -public class CSVSparkTransformServerNoJsonTest { - - private static CSVSparkTransformServer server; - private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); - private static TransformProcess transformProcess = - new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build(); - private static File fileSave = new File(UUID.randomUUID().toString() + ".json"); - - @BeforeClass - public static void before() throws Exception { - server = new CSVSparkTransformServer(); - FileUtils.write(fileSave, transformProcess.toJson()); - - // Only one time - Unirest.setObjectMapper(new ObjectMapper() { - private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = - new org.nd4j.shade.jackson.databind.ObjectMapper(); - - public T readValue(String value, Class valueType) { - try { - return jacksonObjectMapper.readValue(value, valueType); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - public String writeValue(Object value) { - try { - return jacksonObjectMapper.writeValueAsString(value); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - }); - - server.runMain(new String[] {"-dp", "9050"}); - } - - @AfterClass - public static void after() throws Exception { - fileSave.delete(); - server.stop(); - - } - - - - @Test - public void testServer() throws Exception { - assertTrue(server.getTransform() == null); - JsonNode jsonStatus = Unirest.post("http://localhost:9050/transformprocess") - .header("accept", "application/json").header("Content-Type", "application/json") - .body(transformProcess.toJson()).asJson().getBody(); - assumeNotNull(server.getTransform()); - - String[] values = new String[] {"1.0", "2.0"}; - SingleCSVRecord record = new SingleCSVRecord(values); - JsonNode jsonNode = - Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json") - .header("Content-Type", "application/json").body(record).asJson().getBody(); - SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental") - .header("accept", "application/json").header("Content-Type", "application/json").body(record) - .asObject(SingleCSVRecord.class).getBody(); - - BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); - for (int i = 0; i < 3; i++) - batchCSVRecord.add(singleCsvRecord); - /* BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform") - .header("accept", "application/json").header("Content-Type", "application/json") - 
.body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody(); - - Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray") - .header("accept", "application/json").header("Content-Type", "application/json").body(record) - .asObject(Base64NDArrayBody.class).getBody(); -*/ - Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray") - .header("accept", "application/json").header("Content-Type", "application/json") - .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody(); - - - - } - -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java deleted file mode 100644 index a3af5f2c6..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java +++ /dev/null @@ -1,121 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
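
Each of the transform-server tests in this diff repeats the same Unirest-to-Jackson bridge before starting the server. A sketch of that wiring pulled out into a helper (the class name is hypothetical; the `ObjectMapper` callback interface is the one the tests implement verbatim above):

```java
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;

import java.io.IOException;

/** Hypothetical helper collecting the Unirest/Jackson wiring repeated in the tests. */
public final class UnirestJson {
    private UnirestJson() {}

    /** Register a Jackson-backed mapper so Unirest can (de)serialize request/response bodies. */
    public static void register() {
        Unirest.setObjectMapper(new ObjectMapper() {
            private final org.nd4j.shade.jackson.databind.ObjectMapper jackson =
                    new org.nd4j.shade.jackson.databind.ObjectMapper();

            @Override
            public <T> T readValue(String value, Class<T> valueType) {
                try {
                    return jackson.readValue(value, valueType);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }

            @Override
            public String writeValue(Object value) {
                try {
                    return jackson.writeValueAsString(value);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        });
    }
}
```

Calling `UnirestJson.register()` once in a `@BeforeClass` method would replace the anonymous mapper block each test currently declares inline.
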
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.transform; - - -import com.mashape.unirest.http.JsonNode; -import com.mashape.unirest.http.ObjectMapper; -import com.mashape.unirest.http.Unirest; -import org.apache.commons.io.FileUtils; -import org.datavec.api.transform.TransformProcess; -import org.datavec.api.transform.schema.Schema; -import org.datavec.spark.inference.server.CSVSparkTransformServer; -import org.datavec.spark.inference.model.model.Base64NDArrayBody; -import org.datavec.spark.inference.model.model.BatchCSVRecord; -import org.datavec.spark.inference.model.model.SingleCSVRecord; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; - -import java.io.File; -import java.io.IOException; -import java.util.UUID; - -public class CSVSparkTransformServerTest { - - private static CSVSparkTransformServer server; - private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); - private static TransformProcess transformProcess = - new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build(); - private static File fileSave = new File(UUID.randomUUID().toString() + ".json"); - - @BeforeClass - public static void before() throws Exception { - server = new CSVSparkTransformServer(); - FileUtils.write(fileSave, transformProcess.toJson()); - // Only one time - - Unirest.setObjectMapper(new ObjectMapper() { - private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = - new org.nd4j.shade.jackson.databind.ObjectMapper(); - - public T readValue(String value, Class valueType) { - try { - return jacksonObjectMapper.readValue(value, valueType); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - public String writeValue(Object value) { - try { - return jacksonObjectMapper.writeValueAsString(value); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - }); - - server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9050"}); - } - - @AfterClass - public static void after() throws Exception { - fileSave.deleteOnExit(); - server.stop(); - - } - - - - @Test - public void testServer() throws Exception { - String[] values = new String[] {"1.0", "2.0"}; - SingleCSVRecord record = new SingleCSVRecord(values); - JsonNode jsonNode = - Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json") - .header("Content-Type", "application/json").body(record).asJson().getBody(); - SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental") - .header("accept", "application/json").header("Content-Type", "application/json").body(record) - .asObject(SingleCSVRecord.class).getBody(); - - BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); - for (int i = 0; i < 3; i++) - batchCSVRecord.add(singleCsvRecord); - BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform") - .header("accept", "application/json").header("Content-Type", "application/json") - .body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody(); - - Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray") - .header("accept", "application/json").header("Content-Type", "application/json").body(record) - .asObject(Base64NDArrayBody.class).getBody(); - - Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray") - 
.header("accept", "application/json").header("Content-Type", "application/json") - .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody(); - - - - - - } - -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java deleted file mode 100644 index 12f754acd..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java +++ /dev/null @@ -1,164 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.transform; - - -import com.mashape.unirest.http.JsonNode; -import com.mashape.unirest.http.ObjectMapper; -import com.mashape.unirest.http.Unirest; -import org.apache.commons.io.FileUtils; -import org.datavec.image.transform.ImageTransformProcess; -import org.datavec.spark.inference.server.ImageSparkTransformServer; -import org.datavec.spark.inference.model.model.Base64NDArrayBody; -import org.datavec.spark.inference.model.model.BatchImageRecord; -import org.datavec.spark.inference.model.model.SingleImageRecord; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.common.io.ClassPathResource; -import org.nd4j.serde.base64.Nd4jBase64; - -import java.io.File; -import java.io.IOException; -import java.util.UUID; - -import static org.junit.Assert.assertEquals; - -public class ImageSparkTransformServerTest { - - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - - private static ImageSparkTransformServer server; - private static File fileSave = new File(UUID.randomUUID().toString() + ".json"); - - @BeforeClass - public static void before() throws Exception { - server = new ImageSparkTransformServer(); - - ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(12345) - .scaleImageTransform(10).cropImageTransform(5).build(); - - FileUtils.write(fileSave, imgTransformProcess.toJson()); - - Unirest.setObjectMapper(new ObjectMapper() { - private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = - new org.nd4j.shade.jackson.databind.ObjectMapper(); - - public T readValue(String value, Class valueType) { - try { - return jacksonObjectMapper.readValue(value, valueType); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - public String writeValue(Object value) { 
- try { - return jacksonObjectMapper.writeValueAsString(value); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - }); - - server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9060"}); - } - - @AfterClass - public static void after() throws Exception { - fileSave.deleteOnExit(); - server.stop(); - - } - - @Test - public void testImageServer() throws Exception { - SingleImageRecord record = - new SingleImageRecord(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI()); - JsonNode jsonNode = Unirest.post("http://localhost:9060/transformincrementalarray") - .header("accept", "application/json").header("Content-Type", "application/json").body(record) - .asJson().getBody(); - Base64NDArrayBody array = Unirest.post("http://localhost:9060/transformincrementalarray") - .header("accept", "application/json").header("Content-Type", "application/json").body(record) - .asObject(Base64NDArrayBody.class).getBody(); - - BatchImageRecord batch = new BatchImageRecord(); - batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI()); - batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile().toURI()); - batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile().toURI()); - - JsonNode jsonNodeBatch = - Unirest.post("http://localhost:9060/transformarray").header("accept", "application/json") - .header("Content-Type", "application/json").body(batch).asJson().getBody(); - Base64NDArrayBody batchArray = Unirest.post("http://localhost:9060/transformarray") - .header("accept", "application/json").header("Content-Type", "application/json").body(batch) - .asObject(Base64NDArrayBody.class).getBody(); - - INDArray result = getNDArray(jsonNode); - assertEquals(1, result.size(0)); - - INDArray batchResult = getNDArray(jsonNodeBatch); - assertEquals(3, batchResult.size(0)); - -// System.out.println(array); - } - - @Test - public void testImageServerMultipart() throws Exception { - JsonNode jsonNode = Unirest.post("http://localhost:9060/transformimage") - .header("accept", "application/json") - .field("file1", new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile()) - .field("file2", new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile()) - .field("file3", new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile()) - .asJson().getBody(); - - - INDArray batchResult = getNDArray(jsonNode); - assertEquals(3, batchResult.size(0)); - -// System.out.println(batchResult); - } - - @Test - public void testImageServerSingleMultipart() throws Exception { - File f = testDir.newFolder(); - File imgFile = new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getTempFileFromArchive(f); - - JsonNode jsonNode = Unirest.post("http://localhost:9060/transformimage") - .header("accept", "application/json") - .field("file1", imgFile) - .asJson().getBody(); - - - INDArray result = getNDArray(jsonNode); - assertEquals(1, result.size(0)); - -// System.out.println(result); - } - - public INDArray getNDArray(JsonNode node) throws IOException { - return Nd4jBase64.fromBase64(node.getObject().getString("ndarray")); - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java 
b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java deleted file mode 100644 index 831dd24f4..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java +++ /dev/null @@ -1,168 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.transform; - - -import com.mashape.unirest.http.JsonNode; -import com.mashape.unirest.http.ObjectMapper; -import com.mashape.unirest.http.Unirest; -import org.apache.commons.io.FileUtils; -import org.datavec.api.transform.TransformProcess; -import org.datavec.api.transform.schema.Schema; -import org.datavec.image.transform.ImageTransformProcess; -import org.datavec.spark.inference.server.SparkTransformServerChooser; -import org.datavec.spark.inference.server.TransformDataType; -import org.datavec.spark.inference.model.model.*; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.common.io.ClassPathResource; -import org.nd4j.serde.base64.Nd4jBase64; - -import java.io.File; -import java.io.IOException; -import java.util.UUID; - -import static org.junit.Assert.assertEquals; - -public class SparkTransformServerTest { - private static SparkTransformServerChooser serverChooser; - private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); - private static TransformProcess transformProcess = - new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble( "2.0").build(); - - private static File imageTransformFile = new File(UUID.randomUUID().toString() + ".json"); - private static File csvTransformFile = new File(UUID.randomUUID().toString() + ".json"); - - @BeforeClass - public static void before() throws Exception { - serverChooser = new SparkTransformServerChooser(); - - ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(12345) - .scaleImageTransform(10).cropImageTransform(5).build(); - - FileUtils.write(imageTransformFile, imgTransformProcess.toJson()); - - FileUtils.write(csvTransformFile, transformProcess.toJson()); - - Unirest.setObjectMapper(new ObjectMapper() { - private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = - new org.nd4j.shade.jackson.databind.ObjectMapper(); - - public T readValue(String value, Class valueType) { - try { - return jacksonObjectMapper.readValue(value, valueType); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - public String writeValue(Object value) { - try { 
- return jacksonObjectMapper.writeValueAsString(value); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - }); - - - } - - @AfterClass - public static void after() throws Exception { - imageTransformFile.deleteOnExit(); - csvTransformFile.deleteOnExit(); - } - - @Test - public void testImageServer() throws Exception { - serverChooser.runMain(new String[] {"--jsonPath", imageTransformFile.getAbsolutePath(), "-dp", "9060", "-dt", - TransformDataType.IMAGE.toString()}); - - SingleImageRecord record = - new SingleImageRecord(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI()); - JsonNode jsonNode = Unirest.post("http://localhost:9060/transformincrementalarray") - .header("accept", "application/json").header("Content-Type", "application/json").body(record) - .asJson().getBody(); - Base64NDArrayBody array = Unirest.post("http://localhost:9060/transformincrementalarray") - .header("accept", "application/json").header("Content-Type", "application/json").body(record) - .asObject(Base64NDArrayBody.class).getBody(); - - BatchImageRecord batch = new BatchImageRecord(); - batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI()); - batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile().toURI()); - batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile().toURI()); - - JsonNode jsonNodeBatch = - Unirest.post("http://localhost:9060/transformarray").header("accept", "application/json") - .header("Content-Type", "application/json").body(batch).asJson().getBody(); - Base64NDArrayBody batchArray = Unirest.post("http://localhost:9060/transformarray") - .header("accept", "application/json").header("Content-Type", "application/json").body(batch) - .asObject(Base64NDArrayBody.class).getBody(); - - INDArray result = getNDArray(jsonNode); - assertEquals(1, result.size(0)); - - INDArray batchResult = getNDArray(jsonNodeBatch); - assertEquals(3, batchResult.size(0)); - - serverChooser.getSparkTransformServer().stop(); - } - - @Test - public void testCSVServer() throws Exception { - serverChooser.runMain(new String[] {"--jsonPath", csvTransformFile.getAbsolutePath(), "-dp", "9050", "-dt", - TransformDataType.CSV.toString()}); - - String[] values = new String[] {"1.0", "2.0"}; - SingleCSVRecord record = new SingleCSVRecord(values); - JsonNode jsonNode = - Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json") - .header("Content-Type", "application/json").body(record).asJson().getBody(); - SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental") - .header("accept", "application/json").header("Content-Type", "application/json").body(record) - .asObject(SingleCSVRecord.class).getBody(); - - BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); - for (int i = 0; i < 3; i++) - batchCSVRecord.add(singleCsvRecord); - BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform") - .header("accept", "application/json").header("Content-Type", "application/json") - .body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody(); - - Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray") - .header("accept", "application/json").header("Content-Type", "application/json").body(record) - .asObject(Base64NDArrayBody.class).getBody(); - - Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray") - 
.header("accept", "application/json").header("Content-Type", "application/json") - .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody(); - - - serverChooser.getSparkTransformServer().stop(); - } - - public INDArray getNDArray(JsonNode node) throws IOException { - return Nd4jBase64.fromBase64(node.getObject().getString("ndarray")); - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/resources/application.conf b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/resources/application.conf deleted file mode 100644 index dbac92d83..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/resources/application.conf +++ /dev/null @@ -1,6 +0,0 @@ -play.modules.enabled += com.lightbend.lagom.discovery.zookeeper.ZooKeeperServiceLocatorModule -play.modules.enabled += io.skymind.skil.service.PredictionModule -play.crypto.secret = as8dufasdfuasdfjkasdkfalksjfk -play.server.pidfile.path=/tmp/RUNNING_PID - -play.server.http.port = 9600 diff --git a/datavec/datavec-spark-inference-parent/pom.xml b/datavec/datavec-spark-inference-parent/pom.xml deleted file mode 100644 index abf3f3b0d..000000000 --- a/datavec/datavec-spark-inference-parent/pom.xml +++ /dev/null @@ -1,68 +0,0 @@ - - - - - - 4.0.0 - - - org.datavec - datavec-parent - 1.0.0-SNAPSHOT - - - datavec-spark-inference-parent - pom - - datavec-spark-inference-parent - - - datavec-spark-inference-server - datavec-spark-inference-client - datavec-spark-inference-model - - - - - - org.datavec - datavec-data-image - ${datavec.version} - - - com.mashape.unirest - unirest-java - ${unirest.version} - - - - - - - test-nd4j-native - - - test-nd4j-cuda-11.0 - - - diff --git a/datavec/pom.xml b/datavec/pom.xml index 4142db170..d1c46077f 100644 --- a/datavec/pom.xml +++ b/datavec/pom.xml @@ -45,7 +45,6 @@ datavec-data datavec-spark datavec-local - datavec-spark-inference-parent datavec-jdbc datavec-excel datavec-arrow diff --git a/deeplearning4j/deeplearning4j-core/pom.xml b/deeplearning4j/deeplearning4j-core/pom.xml index 6efc26d34..27caa6718 100644 --- a/deeplearning4j/deeplearning4j-core/pom.xml +++ b/deeplearning4j/deeplearning4j-core/pom.xml @@ -163,12 +163,6 @@ oshi-core ${oshi.version} - - org.nd4j - nd4j-native - ${project.version} - test - diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java index 271dcd4a4..0a03b1dc7 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasConvolutionUtils.java @@ -59,7 +59,7 @@ public class KerasConvolutionUtils { List stridesList = (List) innerConfig.get(conf.getLAYER_FIELD_CONVOLUTION_STRIDES()); strides = ArrayUtil.toArray(stridesList); } else if (innerConfig.containsKey(conf.getLAYER_FIELD_SUBSAMPLE_LENGTH()) && dimension == 1) { - /* 1D Convolutional layers. */ + /* 1D Convolutional layers. 
*/ if ((int) layerConfig.get("keras_version") == 2) { @SuppressWarnings("unchecked") List stridesList = (List) innerConfig.get(conf.getLAYER_FIELD_SUBSAMPLE_LENGTH()); @@ -163,7 +163,7 @@ public class KerasConvolutionUtils { * @throws InvalidKerasConfigurationException Invalid Keras configuration */ static int[] getUpsamplingSizeFromConfig(Map layerConfig, int dimension, - KerasLayerConfiguration conf) + KerasLayerConfiguration conf) throws InvalidKerasConfigurationException { Map innerConfig = KerasLayerUtils.getInnerLayerConfigFromConfig(layerConfig, conf); int[] size; @@ -200,7 +200,7 @@ public class KerasConvolutionUtils { if (kerasMajorVersion != 2) { if (innerConfig.containsKey(conf.getLAYER_FIELD_NB_ROW()) && dimension == 2 && innerConfig.containsKey(conf.getLAYER_FIELD_NB_COL())) { - /* 2D Convolutional layers. */ + /* 2D Convolutional layers. */ List kernelSizeList = new ArrayList<>(); kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_NB_ROW())); kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_NB_COL())); @@ -208,23 +208,23 @@ public class KerasConvolutionUtils { } else if (innerConfig.containsKey(conf.getLAYER_FIELD_3D_KERNEL_1()) && dimension == 3 && innerConfig.containsKey(conf.getLAYER_FIELD_3D_KERNEL_2()) && innerConfig.containsKey(conf.getLAYER_FIELD_3D_KERNEL_3())) { - /* 3D Convolutional layers. */ + /* 3D Convolutional layers. */ List kernelSizeList = new ArrayList<>(); kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_3D_KERNEL_1())); kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_3D_KERNEL_2())); kernelSizeList.add((Integer) innerConfig.get(conf.getLAYER_FIELD_3D_KERNEL_3())); kernelSize = ArrayUtil.toArray(kernelSizeList); } else if (innerConfig.containsKey(conf.getLAYER_FIELD_FILTER_LENGTH()) && dimension == 1) { - /* 1D Convolutional layers. */ + /* 1D Convolutional layers. */ int filterLength = (int) innerConfig.get(conf.getLAYER_FIELD_FILTER_LENGTH()); kernelSize = new int[]{filterLength}; } else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_SIZE()) && dimension >= 2) { - /* 2D/3D Pooling layers. */ + /* 2D/3D Pooling layers. */ @SuppressWarnings("unchecked") List kernelSizeList = (List) innerConfig.get(conf.getLAYER_FIELD_POOL_SIZE()); kernelSize = ArrayUtil.toArray(kernelSizeList); } else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_1D_SIZE()) && dimension == 1) { - /* 1D Pooling layers. */ + /* 1D Pooling layers. */ int poolSize1D = (int) innerConfig.get(conf.getLAYER_FIELD_POOL_1D_SIZE()); kernelSize = new int[]{poolSize1D}; } else { @@ -242,17 +242,17 @@ public class KerasConvolutionUtils { List kernelSizeList = (List) innerConfig.get(conf.getLAYER_FIELD_KERNEL_SIZE()); kernelSize = ArrayUtil.toArray(kernelSizeList); } else if (innerConfig.containsKey(conf.getLAYER_FIELD_FILTER_LENGTH()) && dimension == 1) { - /* 1D Convolutional layers. */ + /* 1D Convolutional layers. */ @SuppressWarnings("unchecked") List kernelSizeList = (List) innerConfig.get(conf.getLAYER_FIELD_FILTER_LENGTH()); kernelSize = ArrayUtil.toArray(kernelSizeList); } else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_SIZE()) && dimension >= 2) { - /* 2D Pooling layers. */ + /* 2D Pooling layers. */ @SuppressWarnings("unchecked") List kernelSizeList = (List) innerConfig.get(conf.getLAYER_FIELD_POOL_SIZE()); kernelSize = ArrayUtil.toArray(kernelSizeList); } else if (innerConfig.containsKey(conf.getLAYER_FIELD_POOL_1D_SIZE()) && dimension == 1) { - /* 1D Pooling layers. */ + /* 1D Pooling layers. 
*/ @SuppressWarnings("unchecked") List kernelSizeList = (List) innerConfig.get(conf.getLAYER_FIELD_POOL_1D_SIZE()); kernelSize = ArrayUtil.toArray(kernelSizeList); @@ -364,16 +364,17 @@ public class KerasConvolutionUtils { } if ((paddingNoCast.size() == dimension) && !isNested) { - for (int i=0; i < dimension; i++) + for (int i = 0; i < dimension; i++) paddingList.add((int) paddingNoCast.get(i)); padding = ArrayUtil.toArray(paddingList); } else if ((paddingNoCast.size() == dimension) && isNested) { - for (int j=0; j < dimension; j++) { + for (int j = 0; j < dimension; j++) { @SuppressWarnings("unchecked") - List item = (List) paddingNoCast.get(0); + List item = (List) paddingNoCast.get(j); paddingList.add((item.get(0))); paddingList.add((item.get(1))); } + padding = ArrayUtil.toArray(paddingList); } else { throw new InvalidKerasConfigurationException("Found Keras ZeroPadding" + dimension diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasCropping2D.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasCropping2D.java index b4df34c5b..66d49d37a 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasCropping2D.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/convolutional/KerasCropping2D.java @@ -29,6 +29,8 @@ import org.deeplearning4j.nn.conf.layers.convolutional.Cropping2D; import org.deeplearning4j.nn.modelimport.keras.KerasLayer; import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException; import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException; +import org.nd4j.common.util.ArrayUtil; +import org.nd4j.linalg.api.ndarray.INDArray; import java.util.Map; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml deleted file mode 100644 index ee029d09f..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml +++ /dev/null @@ -1,143 +0,0 @@ - - - - - - 4.0.0 - - - org.deeplearning4j - deeplearning4j-nearestneighbors-parent - 1.0.0-SNAPSHOT - - - deeplearning4j-nearestneighbor-server - jar - - deeplearning4j-nearestneighbor-server - - - 1.8 - - - - - org.deeplearning4j - deeplearning4j-nearestneighbors-model - ${project.version} - - - org.deeplearning4j - deeplearning4j-core - ${project.version} - - - io.vertx - vertx-core - ${vertx.version} - - - io.vertx - vertx-web - ${vertx.version} - - - com.mashape.unirest - unirest-java - ${unirest.version} - test - - - org.deeplearning4j - deeplearning4j-nearestneighbors-client - ${project.version} - test - - - com.beust - jcommander - ${jcommander.version} - - - ch.qos.logback - logback-classic - test - - - org.deeplearning4j - deeplearning4j-common-tests - ${project.version} - test - - - - - - - org.apache.maven.plugins - maven-surefire-plugin - - -Dfile.encoding=UTF-8 -Xmx8g - - - *.java - **/*.java - - - - - org.apache.maven.plugins - maven-compiler-plugin - - ${java.compile.version} - ${java.compile.version} - - - - - - - - test-nd4j-native - - - org.nd4j - nd4j-native - ${project.version} - test - - - - - test-nd4j-cuda-11.0 - - - org.nd4j - nd4j-cuda-11.0 - 
${project.version} - test - - - - - diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighbor.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighbor.java deleted file mode 100644 index 88f3a7b46..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighbor.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.nearestneighbor.server; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import org.deeplearning4j.clustering.sptree.DataPoint; -import org.deeplearning4j.clustering.vptree.VPTree; -import org.deeplearning4j.nearestneighbor.model.NearestNeighborRequest; -import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult; -import org.nd4j.linalg.api.ndarray.INDArray; - -import java.util.ArrayList; -import java.util.List; - -@AllArgsConstructor -@Builder -public class NearestNeighbor { - private NearestNeighborRequest record; - private VPTree tree; - private INDArray points; - - public List search() { - INDArray input = points.slice(record.getInputIndex()); - List results = new ArrayList<>(); - if (input.isVector()) { - List add = new ArrayList<>(); - List distances = new ArrayList<>(); - tree.search(input, record.getK(), add, distances); - - if (add.size() != distances.size()) { - throw new IllegalStateException( - String.format("add.size == %d != %d == distances.size", - add.size(), distances.size())); - } - - for (int i=0; i print the usage info - jcmdr.usage(); - if (r.ndarrayPath == null) - log.error("Json path parameter is missing (null)"); - try { - Thread.sleep(500); - } catch (Exception e2) { - } - System.exit(1); - } - - instanceArgs = r; - try { - Vertx vertx = Vertx.vertx(); - vertx.deployVerticle(NearestNeighborsServer.class.getName()); - } catch (Throwable t){ - log.error("Error in NearestNeighborsServer run method",t); - } - } - - @Override - public void start() throws Exception { - instance = this; - - String[] pathArr = instanceArgs.ndarrayPath.split(","); - //INDArray[] pointsArr = new INDArray[pathArr.length]; - // first of all, we read the shapes of the files saved earlier - int rows = 0; - int cols = 0; - for (int i = 0; i < pathArr.length; i++) { - DataBuffer shape = BinarySerde.readShapeFromDisk(new File(pathArr[i])); - - log.info("Loading shape {} of {}; Shape: [{} x {}]", i + 1, pathArr.length, Shape.size(shape, 0), 
- Shape.size(shape, 1)); - - if (Shape.rank(shape) != 2) - throw new DL4JInvalidInputException("NearestNeighborsServer assumes 2D chunks"); - - rows += Shape.size(shape, 0); - - if (cols == 0) - cols = Shape.size(shape, 1); - else if (cols != Shape.size(shape, 1)) - throw new DL4JInvalidInputException( - "NearestNeighborsServer requires equal 2D chunks. Got columns mismatch."); - } - - final List labels = new ArrayList<>(); - if (instanceArgs.labelsPath != null) { - String[] labelsPathArr = instanceArgs.labelsPath.split(","); - for (int i = 0; i < labelsPathArr.length; i++) { - labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8")); - } - } - if (!labels.isEmpty() && labels.size() != rows) - throw new DL4JInvalidInputException(String.format("Number of labels must match number of rows in points matrix (expected %d, found %d)", rows, labels.size())); - - final INDArray points = Nd4j.createUninitialized(rows, cols); - - int lastPosition = 0; - for (int i = 0; i < pathArr.length; i++) { - log.info("Loading chunk {} of {}", i + 1, pathArr.length); - INDArray pointsArr = BinarySerde.readFromDisk(new File(pathArr[i])); - - points.get(NDArrayIndex.interval(lastPosition, lastPosition + pointsArr.rows())).assign(pointsArr); - lastPosition += pointsArr.rows(); - - // let's ensure we don't bring too much stuff in the next loop - System.gc(); - } - - VPTree tree = new VPTree(points, instanceArgs.similarityFunction, instanceArgs.invert); - - //Set play secret key, if required - //http://www.playframework.com/documentation/latest/ApplicationSecret - String crypto = System.getProperty("play.crypto.secret"); - if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) { - byte[] newCrypto = new byte[1024]; - - new Random().nextBytes(newCrypto); - - String base64 = Base64.getEncoder().encodeToString(newCrypto); - System.setProperty("play.crypto.secret", base64); - } - - Router r = Router.router(vertx); - r.route().handler(BodyHandler.create()); //NOTE: Setting this is required to receive request body content at all - createRoutes(r, labels, tree, points); - - vertx.createHttpServer() - .requestHandler(r) - .listen(instanceArgs.port); - } - - private void createRoutes(Router r, List labels, VPTree tree, INDArray points){ - - r.post("/knn").handler(rc -> { - try { - String json = rc.getBodyAsJson().encode(); - NearestNeighborRequest record = JsonMappers.getMapper().readValue(json, NearestNeighborRequest.class); - - if (record == null) { - rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()) - .putHeader("content-type", "application/json") - .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed."))); - return; - } - - NearestNeighbor nearestNeighbor = - NearestNeighbor.builder().points(points).record(record).tree(tree).build(); - - NearestNeighborsResults results = NearestNeighborsResults.builder().results(nearestNeighbor.search()).build(); - - rc.response().setStatusCode(HttpResponseStatus.OK.code()) - .putHeader("content-type", "application/json") - .end(JsonMappers.getMapper().writeValueAsString(results)); - return; - } catch (Throwable e) { - log.error("Error in POST /knn",e); - rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) - .end("Error parsing request - " + e.getMessage()); - return; - } - }); - - r.post("/knnnew").handler(rc -> { - try { - String json = rc.getBodyAsJson().encode(); - Base64NDArrayBody record = JsonMappers.getMapper().readValue(json, Base64NDArrayBody.class); - 
if (record == null) { - rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()) - .putHeader("content-type", "application/json") - .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed."))); - return; - } - - INDArray arr = Nd4jBase64.fromBase64(record.getNdarray()); - List results; - List distances; - - if (record.isForceFillK()) { - VPTreeFillSearch vpTreeFillSearch = new VPTreeFillSearch(tree, record.getK(), arr); - vpTreeFillSearch.search(); - results = vpTreeFillSearch.getResults(); - distances = vpTreeFillSearch.getDistances(); - } else { - results = new ArrayList<>(); - distances = new ArrayList<>(); - tree.search(arr, record.getK(), results, distances); - } - - if (results.size() != distances.size()) { - rc.response() - .setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) - .end(String.format("results.size == %d != %d == distances.size", results.size(), distances.size())); - return; - } - - List nnResult = new ArrayList<>(); - for (int i=0; i results = nearestNeighbor.search(); - assertEquals(1, results.get(0).getIndex()); - assertEquals(2, results.size()); - - assertEquals(1.0, results.get(0).getDistance(), 1e-4); - assertEquals(4.0, results.get(1).getDistance(), 1e-4); - } - - @Test - public void testNearestNeighborInverted() { - double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}}; - INDArray arr = Nd4j.create(data); - - VPTree vpTree = new VPTree(arr, true); - NearestNeighborRequest request = new NearestNeighborRequest(); - request.setK(2); - request.setInputIndex(0); - NearestNeighbor nearestNeighbor = NearestNeighbor.builder().tree(vpTree).points(arr).record(request).build(); - List results = nearestNeighbor.search(); - assertEquals(2, results.get(0).getIndex()); - assertEquals(2, results.size()); - - assertEquals(-4.0, results.get(0).getDistance(), 1e-4); - assertEquals(-1.0, results.get(1).getDistance(), 1e-4); - } - - @Test - public void vpTreeTest() throws Exception { - INDArray matrix = Nd4j.rand(new int[] {400,10}); - INDArray rowVector = matrix.getRow(70); - INDArray resultArr = Nd4j.zeros(400,1); - Executor executor = Executors.newSingleThreadExecutor(); - VPTree vpTree = new VPTree(matrix); - System.out.println("Ran!"); - } - - - - public static int getAvailablePort() { - try { - ServerSocket socket = new ServerSocket(0); - try { - return socket.getLocalPort(); - } finally { - socket.close(); - } - } catch (IOException e) { - throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e); - } - } - - @Test - public void testServer() throws Exception { - int localPort = getAvailablePort(); - Nd4j.getRandom().setSeed(7); - INDArray rand = Nd4j.randn(10, 5); - File writeToTmp = testDir.newFile(); - writeToTmp.deleteOnExit(); - BinarySerde.writeArrayToDisk(rand, writeToTmp); - NearestNeighborsServer.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort", - String.valueOf(localPort)); - - Thread.sleep(3000); - - NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort); - NearestNeighborsResults result = client.knnNew(5, rand.getRow(0)); - assertEquals(5, result.getResults().size()); - NearestNeighborsServer.getInstance().stop(); - } - - - - @Test - public void testFullSearch() throws Exception { - int numRows = 1000; - int numCols = 100; - int numNeighbors = 42; - INDArray points = Nd4j.rand(numRows, numCols); - VPTree tree = new VPTree(points); - INDArray query = Nd4j.rand(new int[] {1, numCols}); 
- VPTreeFillSearch fillSearch = new VPTreeFillSearch(tree, numNeighbors, query); - fillSearch.search(); - List results = fillSearch.getResults(); - List distances = fillSearch.getDistances(); - assertEquals(numNeighbors, distances.size()); - assertEquals(numNeighbors, results.size()); - } - - @Test - public void testDistances() { - - INDArray indArray = Nd4j.create(new float[][]{{3, 4}, {1, 2}, {5, 6}}); - INDArray record = Nd4j.create(new float[][]{{7, 6}}); - VPTree vpTree = new VPTree(indArray, "euclidean", false); - VPTreeFillSearch vpTreeFillSearch = new VPTreeFillSearch(vpTree, 3, record); - vpTreeFillSearch.search(); - //System.out.println(vpTreeFillSearch.getResults()); - System.out.println(vpTreeFillSearch.getDistances()); - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml deleted file mode 100644 index 7e0af0fa1..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml +++ /dev/null @@ -1,46 +0,0 @@ - - - - - - logs/application.log - - %logger{15} - %message%n%xException{5} - - - - - - - %logger{15} - %message%n%xException{5} - - - - - - - - - - - - - \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml deleted file mode 100644 index 55d7b83f9..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml +++ /dev/null @@ -1,60 +0,0 @@ - - - - - - 4.0.0 - - - org.deeplearning4j - deeplearning4j-nearestneighbors-parent - 1.0.0-SNAPSHOT - - - deeplearning4j-nearestneighbors-client - jar - - deeplearning4j-nearestneighbors-client - - - - com.mashape.unirest - unirest-java - ${unirest.version} - - - org.deeplearning4j - deeplearning4j-nearestneighbors-model - ${project.version} - - - - - - test-nd4j-native - - - test-nd4j-cuda-11.0 - - - diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java deleted file mode 100644 index 570e75bf9..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java +++ /dev/null @@ -1,137 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
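
The `/knnnew` handler above picks between two search modes: plain `tree.search(...)`, and `VPTreeFillSearch` when `forceFillK` is set, which as used here keeps collecting until exactly k results are returned. A minimal sketch contrasting the two, assuming `VPTreeFillSearch` lives alongside `VPTree` in `org.deeplearning4j.clustering.vptree`:

```java
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.clustering.vptree.VPTreeFillSearch;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ArrayList;
import java.util.List;

public class KnnModes {
    public static void main(String[] args) {
        INDArray points = Nd4j.rand(100, 10);
        INDArray query = Nd4j.rand(1, 10);
        VPTree tree = new VPTree(points);

        // Plain search: results/distances are filled in by the tree walk.
        List<DataPoint> results = new ArrayList<>();
        List<Double> distances = new ArrayList<>();
        tree.search(query, 5, results, distances);

        // Fill search: the forceFillK branch of the /knnnew handler above,
        // which runs search() and then reads back exactly k results.
        VPTreeFillSearch fill = new VPTreeFillSearch(tree, 5, query);
        fill.search();
        List<DataPoint> filled = fill.getResults();
        List<Double> fillDistances = fill.getDistances();

        System.out.println(results.size() + " vs " + filled.size());
    }
}
```
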
See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.nearestneighbor.client; - -import com.mashape.unirest.http.ObjectMapper; -import com.mashape.unirest.http.Unirest; -import com.mashape.unirest.request.HttpRequest; -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.Setter; -import lombok.val; -import org.deeplearning4j.nearestneighbor.model.*; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.serde.base64.Nd4jBase64; -import org.nd4j.shade.jackson.core.JsonProcessingException; - -import java.io.IOException; - -@AllArgsConstructor -public class NearestNeighborsClient { - - private String url; - @Setter - @Getter - protected String authToken; - - public NearestNeighborsClient(String url){ - this(url, null); - } - - static { - // Only one time - - Unirest.setObjectMapper(new ObjectMapper() { - private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = - new org.nd4j.shade.jackson.databind.ObjectMapper(); - - public T readValue(String value, Class valueType) { - try { - return jacksonObjectMapper.readValue(value, valueType); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - public String writeValue(Object value) { - try { - return jacksonObjectMapper.writeValueAsString(value); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - }); - } - - - /** - * Runs knn on the given index - * with the given k (note that this is for data - * already within the existing dataset not new data) - * @param index the index of the - * EXISTING ndarray - * to run a search on - * @param k the number of results - * @return - * @throws Exception - */ - public NearestNeighborsResults knn(int index, int k) throws Exception { - NearestNeighborRequest request = new NearestNeighborRequest(); - request.setInputIndex(index); - request.setK(k); - val req = Unirest.post(url + "/knn"); - req.header("accept", "application/json") - .header("Content-Type", "application/json").body(request); - addAuthHeader(req); - - NearestNeighborsResults ret = req.asObject(NearestNeighborsResults.class).getBody(); - return ret; - } - - /** - * Run a k nearest neighbors search - * on a NEW data point - * @param k the number of results - * to retrieve - * @param arr the array to run the search on. 
- * Note that this must be a row vector - * @return - * @throws Exception - */ - public NearestNeighborsResults knnNew(int k, INDArray arr) throws Exception { - Base64NDArrayBody base64NDArrayBody = - Base64NDArrayBody.builder().k(k).ndarray(Nd4jBase64.base64String(arr)).build(); - - val req = Unirest.post(url + "/knnnew"); - req.header("accept", "application/json") - .header("Content-Type", "application/json").body(base64NDArrayBody); - addAuthHeader(req); - - NearestNeighborsResults ret = req.asObject(NearestNeighborsResults.class).getBody(); - - return ret; - } - - - /** - * Add the specified authentication header to the specified HttpRequest - * - * @param request HTTP Request to add the authentication header to - */ - protected HttpRequest addAuthHeader(HttpRequest request) { - if (authToken != null) { - request.header("authorization", "Bearer " + authToken); - } - - return request; - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml deleted file mode 100644 index 09a72628e..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml +++ /dev/null @@ -1,61 +0,0 @@ - - - - - - 4.0.0 - - - org.deeplearning4j - deeplearning4j-nearestneighbors-parent - 1.0.0-SNAPSHOT - - - deeplearning4j-nearestneighbors-model - jar - - deeplearning4j-nearestneighbors-model - - - - org.projectlombok - lombok - ${lombok.version} - provided - - - org.nd4j - nd4j-api - ${nd4j.version} - - - - - - test-nd4j-native - - - test-nd4j-cuda-11.0 - - - diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/Base64NDArrayBody.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/Base64NDArrayBody.java deleted file mode 100644 index c68f48ebe..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/Base64NDArrayBody.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
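
For reference, a short usage sketch of the client deleted above; the URL and port are hypothetical, and the auth token is optional (it is sent as a `Bearer` header by `addAuthHeader`):

```java
import org.deeplearning4j.nearestneighbor.client.NearestNeighborsClient;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResults;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class KnnClientExample {
    public static void main(String[] args) throws Exception {
        // Point this at a running NearestNeighborsServer; URL/port are made up.
        NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:9000");
        client.setAuthToken("my-token"); // optional

        // knn: search around a point already in the server's index.
        NearestNeighborsResults byIndex = client.knn(0, 5);

        // knnNew: search around a new row vector.
        INDArray query = Nd4j.rand(1, 10);
        NearestNeighborsResults byVector = client.knnNew(5, query);

        for (NearestNeighborsResult r : byVector.getResults()) {
            System.out.println(r.getIndex() + " -> " + r.getDistance());
        }
    }
}
```
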
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.nearestneighbor.model; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.io.Serializable; - -@Data -@AllArgsConstructor -@NoArgsConstructor -@Builder -public class Base64NDArrayBody implements Serializable { - private String ndarray; - private int k; - private boolean forceFillK; -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/BatchRecord.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/BatchRecord.java deleted file mode 100644 index f2a9475a1..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/BatchRecord.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.nearestneighbor.model; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; -import org.nd4j.linalg.dataset.DataSet; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; - -@Data -@AllArgsConstructor -@Builder -@NoArgsConstructor -public class BatchRecord implements Serializable { - private List records; - - /** - * Add a record - * @param record - */ - public void add(CSVRecord record) { - if (records == null) - records = new ArrayList<>(); - records.add(record); - } - - - /** - * Return a batch record based on a dataset - * @param dataSet the dataset to get the batch record for - * @return the batch record - */ - public static BatchRecord fromDataSet(DataSet dataSet) { - BatchRecord batchRecord = new BatchRecord(); - for (int i = 0; i < dataSet.numExamples(); i++) { - batchRecord.add(CSVRecord.fromRow(dataSet.get(i))); - } - - return batchRecord; - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/CSVRecord.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/CSVRecord.java deleted file mode 100644 index ef642bf0d..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/CSVRecord.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
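// Editorial example (not part of the original diff): converting an ND4J DataSet into
// the BatchRecord defined above; each example in the DataSet becomes one CSVRecord
// via CSVRecord.fromRow(...).
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

class BatchRecordExample {
    static BatchRecord demo() {
        DataSet ds = new DataSet(Nd4j.rand(4, 3), Nd4j.eye(4)); // 4 examples, one-hot labels
        return BatchRecord.fromDataSet(ds);                     // yields 4 CSVRecords
    }
}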
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.nearestneighbor.model; - -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; -import org.nd4j.linalg.dataset.DataSet; - -import java.io.Serializable; - -@Data -@AllArgsConstructor -@NoArgsConstructor -public class CSVRecord implements Serializable { - private String[] values; - - /** - * Instantiate a CSV record from a vector. - * For classification (a one-hot label matrix), the index of the - * maximum label value is appended to the end of the feature values; - * for regression, all label values are appended. - * @param row the input vectors - * @return the record from this {@link DataSet} - */ - public static CSVRecord fromRow(DataSet row) { - if (!row.getFeatures().isVector() && !row.getFeatures().isScalar()) - throw new IllegalArgumentException("Passed in dataset must represent a scalar or vector"); - if (!row.getLabels().isVector() && !row.getLabels().isScalar()) - throw new IllegalArgumentException("Passed in dataset labels must be a scalar or vector"); - //classification - CSVRecord record; - int idx = 0; - if (row.getLabels().sumNumber().doubleValue() == 1.0) { - String[] values = new String[row.getFeatures().columns() + 1]; - for (int i = 0; i < row.getFeatures().length(); i++) { - values[idx++] = String.valueOf(row.getFeatures().getDouble(i)); - } - int maxIdx = 0; - for (int i = 0; i < row.getLabels().length(); i++) { - if (row.getLabels().getDouble(maxIdx) < row.getLabels().getDouble(i)) { - maxIdx = i; - } - } - - values[idx++] = String.valueOf(maxIdx); - record = new CSVRecord(values); - } - //regression (any number of values) - else { - String[] values = new String[row.getFeatures().columns() + row.getLabels().columns()]; - for (int i = 0; i < row.getFeatures().length(); i++) { - values[idx++] = String.valueOf(row.getFeatures().getDouble(i)); - } - for (int i = 0; i < row.getLabels().length(); i++) { - values[idx++] = String.valueOf(row.getLabels().getDouble(i)); - } - - - record = new CSVRecord(values); - - } - return record; - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborRequest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborRequest.java deleted file mode 100644 index 5044c6b35..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborRequest.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.nearestneighbor.model; - -import lombok.Data; - -import java.io.Serializable; - -@Data -public class NearestNeighborRequest implements Serializable { - private int k; - private int inputIndex; - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborsResult.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborsResult.java deleted file mode 100644 index 768b0dfc9..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborsResult.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.nearestneighbor.model; - -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; -@Data -@AllArgsConstructor -@NoArgsConstructor -public class NearestNeighborsResult { - public NearestNeighborsResult(int index, double distance) { - this(index, distance, null); - } - - private int index; - private double distance; - private String label; -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborsResults.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborsResults.java deleted file mode 100644 index d95c68fb6..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborsResults.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. 
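// Editorial example (not part of the original diff): the two POJOs above in use. The
// getters and setters are generated by Lombok's @Data; NearestNeighborsResult's
// two-argument constructor leaves the label null. The comment on inputIndex is an
// interpretation -- the server side that consumes it is not shown in this diff.
class RequestResultExample {
    static void demo() {
        NearestNeighborRequest request = new NearestNeighborRequest();
        request.setK(3);          // number of neighbours to return
        request.setInputIndex(0); // presumably the index of an already-stored vector to query around

        NearestNeighborsResult hit = new NearestNeighborsResult(7, 0.42);
        System.out.println(hit.getIndex() + " @ " + hit.getDistance());
    }
}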
- * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.nearestneighbor.model; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.io.Serializable; -import java.util.List; - -@Data -@Builder -@NoArgsConstructor -@AllArgsConstructor -public class NearestNeighborsResults implements Serializable { - private List results; - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml deleted file mode 100644 index 5df85229d..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml +++ /dev/null @@ -1,103 +0,0 @@ - - - - - - 4.0.0 - - - org.deeplearning4j - deeplearning4j-nearestneighbors-parent - 1.0.0-SNAPSHOT - - - nearestneighbor-core - jar - - nearestneighbor-core - - - - org.nd4j - nd4j-api - ${nd4j.version} - - - junit - junit - - - ch.qos.logback - logback-classic - test - - - org.deeplearning4j - deeplearning4j-nn - ${project.version} - - - org.deeplearning4j - deeplearning4j-datasets - ${project.version} - test - - - joda-time - joda-time - 2.10.3 - test - - - org.deeplearning4j - deeplearning4j-common-tests - ${project.version} - test - - - - - - test-nd4j-native - - - org.nd4j - nd4j-native - ${project.version} - test - - - - - test-nd4j-cuda-11.0 - - - org.nd4j - nd4j-cuda-11.0 - ${project.version} - test - - - - - diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/BaseClusteringAlgorithm.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/BaseClusteringAlgorithm.java deleted file mode 100755 index e7e467ad3..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/BaseClusteringAlgorithm.java +++ /dev/null @@ -1,218 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
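// Editorial example (not part of the original diff): unwrapping the
// NearestNeighborsResults container defined above, e.g. the value returned by the
// client's knnNew(...) call earlier in this diff. This sketch assumes the results
// list holds NearestNeighborsResult entries, as the surrounding API implies.
class ResultsExample {
    static void print(NearestNeighborsResults results) {
        for (NearestNeighborsResult r : results.getResults()) {
            System.out.printf("index=%d distance=%.4f label=%s%n",
                    r.getIndex(), r.getDistance(), r.getLabel());
        }
    }
}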
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.algorithm; - -import lombok.AccessLevel; -import lombok.NoArgsConstructor; -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.apache.commons.lang3.ArrayUtils; -import org.deeplearning4j.clustering.cluster.Cluster; -import org.deeplearning4j.clustering.cluster.ClusterSet; -import org.deeplearning4j.clustering.cluster.ClusterUtils; -import org.deeplearning4j.clustering.cluster.Point; -import org.deeplearning4j.clustering.info.ClusterSetInfo; -import org.deeplearning4j.clustering.iteration.IterationHistory; -import org.deeplearning4j.clustering.iteration.IterationInfo; -import org.deeplearning4j.clustering.strategy.ClusteringStrategy; -import org.deeplearning4j.clustering.strategy.ClusteringStrategyType; -import org.deeplearning4j.clustering.strategy.OptimisationStrategy; -import org.deeplearning4j.clustering.util.MultiThreadUtils; -import org.nd4j.common.base.Preconditions; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.ExecutorService; - -@Slf4j -@NoArgsConstructor(access = AccessLevel.PROTECTED) -public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializable { - - private static final long serialVersionUID = 338231277453149972L; - - private ClusteringStrategy clusteringStrategy; - private IterationHistory iterationHistory; - private int currentIteration = 0; - private ClusterSet clusterSet; - private List initialPoints; - private transient ExecutorService exec; - private boolean useKmeansPlusPlus; - - - protected BaseClusteringAlgorithm(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) { - this.clusteringStrategy = clusteringStrategy; - this.exec = MultiThreadUtils.newExecutorService(); - this.useKmeansPlusPlus = useKmeansPlusPlus; - } - - /** - * - * @param clusteringStrategy - * @return - */ - public static BaseClusteringAlgorithm setup(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) { - return new BaseClusteringAlgorithm(clusteringStrategy, useKmeansPlusPlus); - } - - /** - * - * @param points - * @return - */ - public ClusterSet applyTo(List points) { - resetState(points); - initClusters(useKmeansPlusPlus); - iterations(); - return clusterSet; - } - - private void resetState(List points) { - this.iterationHistory = new IterationHistory(); - this.currentIteration = 0; - this.clusterSet = null; - this.initialPoints = points; - } - - /** Run clustering iterations until a - * termination condition is hit. - * This is done by first classifying all points, - * and then updating cluster centers based on - * those classified points - */ - private void iterations() { - int iterationCount = 0; - while ((clusteringStrategy.getTerminationCondition() != null - && !clusteringStrategy.getTerminationCondition().isSatisfied(iterationHistory)) - || iterationHistory.getMostRecentIterationInfo().isStrategyApplied()) { - currentIteration++; - removePoints(); - classifyPoints(); - applyClusteringStrategy(); - log.trace("Completed clustering iteration {}", ++iterationCount); - } - } - - protected void classifyPoints() { - //Classify points. 
This also adds each point to the ClusterSet - ClusterSetInfo clusterSetInfo = ClusterUtils.classifyPoints(clusterSet, initialPoints, exec); - //Update the cluster centers, based on the points within each cluster - ClusterUtils.refreshClustersCenters(clusterSet, clusterSetInfo, exec); - iterationHistory.getIterationsInfos().put(currentIteration, - new IterationInfo(currentIteration, clusterSetInfo)); - } - - /** - * Initialize the - * cluster centers at random - */ - protected void initClusters(boolean kMeansPlusPlus) { - log.info("Generating initial clusters"); - List points = new ArrayList<>(initialPoints); - - //Initialize the ClusterSet with a single cluster center (based on position of one of the points chosen randomly) - val random = Nd4j.getRandom(); - Distance distanceFn = clusteringStrategy.getDistanceFunction(); - int initialClusterCount = clusteringStrategy.getInitialClusterCount(); - clusterSet = new ClusterSet(distanceFn, - clusteringStrategy.inverseDistanceCalculation(), new long[]{initialClusterCount, points.get(0).getArray().length()}); - clusterSet.addNewClusterWithCenter(points.remove(random.nextInt(points.size()))); - - - //dxs: distances between - // each point and nearest cluster to that point - INDArray dxs = Nd4j.create(points.size()); - dxs.addi(clusteringStrategy.inverseDistanceCalculation() ? -Double.MAX_VALUE : Double.MAX_VALUE); - - //Generate the initial cluster centers, by randomly selecting a point between 0 and max distance - //Thus, we are more likely to select (as a new cluster center) a point that is far from an existing cluster - while (clusterSet.getClusterCount() < initialClusterCount && !points.isEmpty()) { - dxs = ClusterUtils.computeSquareDistancesFromNearestCluster(clusterSet, points, dxs, exec); - double summed = Nd4j.sum(dxs).getDouble(0); - double r = kMeansPlusPlus ? random.nextDouble() * summed: - random.nextFloat() * dxs.maxNumber().doubleValue(); - - for (int i = 0; i < dxs.length(); i++) { - double distance = dxs.getDouble(i); - Preconditions.checkState(distance >= 0, "Encountered negative distance: distance function is not valid? 
Distance " + - "function must return values >= 0, got distance %s for function %s", distance, distanceFn); - if (dxs.getDouble(i) >= r) { - clusterSet.addNewClusterWithCenter(points.remove(i)); - dxs = Nd4j.create(ArrayUtils.remove(dxs.data().asDouble(), i)); - break; - } - } - } - - ClusterSetInfo initialClusterSetInfo = ClusterUtils.computeClusterSetInfo(clusterSet); - iterationHistory.getIterationsInfos().put(currentIteration, - new IterationInfo(currentIteration, initialClusterSetInfo)); - } - - - protected void applyClusteringStrategy() { - if (!isStrategyApplicableNow()) - return; - - ClusterSetInfo clusterSetInfo = iterationHistory.getMostRecentClusterSetInfo(); - if (!clusteringStrategy.isAllowEmptyClusters()) { - int removedCount = removeEmptyClusters(clusterSetInfo); - if (removedCount > 0) { - iterationHistory.getMostRecentIterationInfo().setStrategyApplied(true); - - if (clusteringStrategy.isStrategyOfType(ClusteringStrategyType.FIXED_CLUSTER_COUNT) - && clusterSet.getClusterCount() < clusteringStrategy.getInitialClusterCount()) { - int splitCount = ClusterUtils.splitMostSpreadOutClusters(clusterSet, clusterSetInfo, - clusteringStrategy.getInitialClusterCount() - clusterSet.getClusterCount(), exec); - if (splitCount > 0) - iterationHistory.getMostRecentIterationInfo().setStrategyApplied(true); - } - } - } - if (clusteringStrategy.isStrategyOfType(ClusteringStrategyType.OPTIMIZATION)) - optimize(); - } - - protected void optimize() { - ClusterSetInfo clusterSetInfo = iterationHistory.getMostRecentClusterSetInfo(); - OptimisationStrategy optimization = (OptimisationStrategy) clusteringStrategy; - boolean applied = ClusterUtils.applyOptimization(optimization, clusterSet, clusterSetInfo, exec); - iterationHistory.getMostRecentIterationInfo().setStrategyApplied(applied); - } - - private boolean isStrategyApplicableNow() { - return clusteringStrategy.isOptimizationDefined() && iterationHistory.getIterationCount() != 0 - && clusteringStrategy.isOptimizationApplicableNow(iterationHistory); - } - - protected int removeEmptyClusters(ClusterSetInfo clusterSetInfo) { - List<Cluster> removedClusters = clusterSet.removeEmptyClusters(); - clusterSetInfo.removeClusterInfos(removedClusters); - return removedClusters.size(); - } - - protected void removePoints() { - clusterSet.removePoints(); - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/ClusteringAlgorithm.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/ClusteringAlgorithm.java deleted file mode 100644 index 02ac17f39..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/ClusteringAlgorithm.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.algorithm; - -import org.deeplearning4j.clustering.cluster.ClusterSet; -import org.deeplearning4j.clustering.cluster.Point; - -import java.util.List; - -public interface ClusteringAlgorithm { - - /** - * Apply a clustering - * algorithm for a given result - * @param points - * @return - */ - ClusterSet applyTo(List points); - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/Distance.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/Distance.java deleted file mode 100644 index 657df3dfa..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/Distance.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.algorithm; - -public enum Distance { - EUCLIDEAN("euclidean"), - COSINE_DISTANCE("cosinedistance"), - COSINE_SIMILARITY("cosinesimilarity"), - MANHATTAN("manhattan"), - DOT("dot"), - JACCARD("jaccard"), - HAMMING("hamming"); - - private String functionName; - private Distance(String name) { - functionName = name; - } - - @Override - public String toString() { - return functionName; - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/CentersHolder.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/CentersHolder.java deleted file mode 100644 index 8a39d8bc3..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/CentersHolder.java +++ /dev/null @@ -1,105 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. 
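// Editorial example (not part of the original diff): a minimal sketch of driving the
// clustering API above end to end. FixedClusterCountStrategy and its
// setup(clusterCount, distance, inverse) factory are assumptions -- that class is not
// part of this diff; BaseClusteringAlgorithm, ClusteringAlgorithm, Point, ClusterSet
// and Distance all appear above.
import org.deeplearning4j.clustering.algorithm.BaseClusteringAlgorithm;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.Point;
import org.deeplearning4j.clustering.strategy.ClusteringStrategy;
import org.deeplearning4j.clustering.strategy.FixedClusterCountStrategy;
import org.nd4j.linalg.factory.Nd4j;

import java.util.List;

public class ClusteringExample {
    public static void main(String[] args) {
        List<Point> points = Point.toPoints(Nd4j.rand(100, 5));    // 100 random 5-d points
        ClusteringStrategy strategy =
                FixedClusterCountStrategy.setup(3, Distance.EUCLIDEAN, false); // assumed factory
        BaseClusteringAlgorithm kmeans = BaseClusteringAlgorithm.setup(strategy, true); // k-means++ init
        ClusterSet clusters = kmeans.applyTo(points);
        System.out.println("clusters: " + clusters.getClusterCount());
    }
}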
- * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.cluster; - -import org.deeplearning4j.clustering.algorithm.Distance; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.ReduceOp; -import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.Pair; - -public class CentersHolder { - private INDArray centers; - private long index = 0; - - protected transient ReduceOp op; - protected ArgMin imin; - protected transient INDArray distances; - protected transient INDArray argMin; - - private long rows, cols; - - public CentersHolder(long rows, long cols) { - this.rows = rows; - this.cols = cols; - } - - public INDArray getCenters() { - return this.centers; - } - - public synchronized void addCenter(INDArray pointView) { - if (centers == null) - this.centers = Nd4j.create(pointView.dataType(), new long[] {rows, cols}); - - centers.putRow(index++, pointView); - } - - public synchronized Pair getCenterByMinDistance(Point point, Distance distanceFunction) { - if (distances == null) - distances = Nd4j.create(centers.dataType(), centers.rows()); - - if (argMin == null) - argMin = Nd4j.createUninitialized(DataType.LONG, new long[0]); - - if (op == null) { - op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1); - imin = new ArgMin(distances, argMin); - op.setZ(distances); - } - - op.setY(point.getArray()); - - Nd4j.getExecutioner().exec(op); - Nd4j.getExecutioner().exec(imin); - - Pair result = new Pair<>(); - result.setFirst(distances.getDouble(argMin.getLong(0))); - result.setSecond(argMin.getLong(0)); - return result; - } - - public synchronized INDArray getMinDistances(Point point, Distance distanceFunction) { - if (distances == null) - distances = Nd4j.create(centers.dataType(), centers.rows()); - - if (argMin == null) - argMin = Nd4j.createUninitialized(DataType.LONG, new long[0]); - - if (op == null) { - op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1); - imin = new ArgMin(distances, argMin); - op.setZ(distances); - } - - op.setY(point.getArray()); - - Nd4j.getExecutioner().exec(op); - Nd4j.getExecutioner().exec(imin); - - System.out.println(distances); - return distances; - } - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/Cluster.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/Cluster.java deleted file mode 100644 index 7f4f221e5..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/Cluster.java +++ /dev/null @@ -1,150 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 
which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.cluster; - -import lombok.Data; -import org.deeplearning4j.clustering.algorithm.Distance; -import org.nd4j.linalg.factory.Nd4j; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.UUID; - -@Data -public class Cluster implements Serializable { - - private String id = UUID.randomUUID().toString(); - private String label; - - private Point center; - private List<Point> points = Collections.synchronizedList(new ArrayList<Point>()); - private boolean inverse = false; - private Distance distanceFunction; - - public Cluster() { - super(); - } - - /** - * - * @param center - * @param distanceFunction - */ - public Cluster(Point center, Distance distanceFunction) { - this(center, false, distanceFunction); - } - - /** - * - * @param center - * @param inverse - * @param distanceFunction - */ - public Cluster(Point center, boolean inverse, Distance distanceFunction) { - this.distanceFunction = distanceFunction; - this.inverse = inverse; - setCenter(center); - } - - /** - * Get the distance to the given - * point from the cluster - * @param point the point to get the distance for - * @return the distance between the point and the cluster center - */ - public double getDistanceToCenter(Point point) { - return Nd4j.getExecutioner().execAndReturn( - ClusterUtils.createDistanceFunctionOp(distanceFunction, center.getArray(), point.getArray())) - .getFinalResult().doubleValue(); - } - - /** - * Add a point to the cluster - * @param point - */ - public void addPoint(Point point) { - addPoint(point, true); - } - - /** - * Add a point to the cluster - * @param point the point to add - * @param moveClusterCenter whether to update - * the cluster centroid or not - */ - public void addPoint(Point point, boolean moveClusterCenter) { - if (moveClusterCenter) { - if (isInverse()) { - center.getArray().muli(points.size()).subi(point.getArray()).divi(points.size() + 1); - } else { - center.getArray().muli(points.size()).addi(point.getArray()).divi(points.size() + 1); - } - } - - getPoints().add(point); - } - - /** - * Clear out the points - */ - public void removePoints() { - if (getPoints() != null) - getPoints().clear(); - } - - /** - * Whether the cluster is empty or not - * @return - */ - public boolean isEmpty() { - return points == null || points.isEmpty(); - } - - /** - * Return the point with the given id - * @param id - * @return - */ - public Point getPoint(String id) { - for (Point point : points) - if (id.equals(point.getId())) - return point; - return null; - } - - /** - * Remove the point and return it - * @param id - * @return - */ - public Point removePoint(String id) { - Point removePoint = null; - for (Point point : points) - if (id.equals(point.getId())) - removePoint = point; - if (removePoint != null) - points.remove(removePoint); - return removePoint; - } - - -} diff --git 
a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterSet.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterSet.java deleted file mode 100644 index dabfdc7a4..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterSet.java +++ /dev/null @@ -1,259 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.cluster; - -import lombok.Data; -import org.deeplearning4j.clustering.algorithm.Distance; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.Pair; - -import java.io.Serializable; -import java.util.*; - -@Data -public class ClusterSet implements Serializable { - - private Distance distanceFunction; - private List clusters; - private CentersHolder centersHolder; - private Map pointDistribution; - private boolean inverse; - - public ClusterSet(boolean inverse) { - this(null, inverse, null); - } - - public ClusterSet(Distance distanceFunction, boolean inverse, long[] shape) { - this.distanceFunction = distanceFunction; - this.inverse = inverse; - this.clusters = Collections.synchronizedList(new ArrayList()); - this.pointDistribution = Collections.synchronizedMap(new HashMap()); - if (shape != null) - this.centersHolder = new CentersHolder(shape[0], shape[1]); - } - - - public boolean isInverse() { - return inverse; - } - - /** - * - * @param center - * @return - */ - public Cluster addNewClusterWithCenter(Point center) { - Cluster newCluster = new Cluster(center, distanceFunction); - getClusters().add(newCluster); - setPointLocation(center, newCluster); - centersHolder.addCenter(center.getArray()); - return newCluster; - } - - /** - * - * @param point - * @return - */ - public PointClassification classifyPoint(Point point) { - return classifyPoint(point, true); - } - - /** - * - * @param points - */ - public void classifyPoints(List points) { - classifyPoints(points, true); - } - - /** - * - * @param points - * @param moveClusterCenter - */ - public void classifyPoints(List points, boolean moveClusterCenter) { - for (Point point : points) - classifyPoint(point, moveClusterCenter); - } - - /** - * - * @param point - * @param moveClusterCenter - * @return - */ - public PointClassification classifyPoint(Point point, boolean moveClusterCenter) { - Pair nearestCluster = nearestCluster(point); - Cluster newCluster = nearestCluster.getKey(); - boolean locationChange = isPointLocationChange(point, newCluster); - addPointToCluster(point, newCluster, 
moveClusterCenter); - return new PointClassification(nearestCluster.getKey(), nearestCluster.getValue(), locationChange); - } - - private boolean isPointLocationChange(Point point, Cluster newCluster) { - if (!getPointDistribution().containsKey(point.getId())) - return true; - return !getPointDistribution().get(point.getId()).equals(newCluster.getId()); - } - - private void addPointToCluster(Point point, Cluster cluster, boolean moveClusterCenter) { - cluster.addPoint(point, moveClusterCenter); - setPointLocation(point, cluster); - } - - private void setPointLocation(Point point, Cluster cluster) { - pointDistribution.put(point.getId(), cluster.getId()); - } - - - /** - * - * @param point - * @return - */ - public Pair nearestCluster(Point point) { - - /*double minDistance = isInverse() ? Float.MIN_VALUE : Float.MAX_VALUE; - - double currentDistance; - for (Cluster cluster : getClusters()) { - currentDistance = cluster.getDistanceToCenter(point); - if (isInverse()) { - if (currentDistance > minDistance) { - minDistance = currentDistance; - nearestCluster = cluster; - } - } else { - if (currentDistance < minDistance) { - minDistance = currentDistance; - nearestCluster = cluster; - } - } - - }*/ - - Pair nearestCenterData = centersHolder. - getCenterByMinDistance(point, distanceFunction); - Cluster nearestCluster = getClusters().get(nearestCenterData.getSecond().intValue()); - double minDistance = nearestCenterData.getFirst(); - return Pair.of(nearestCluster, minDistance); - } - - /** - * - * @param m1 - * @param m2 - * @return - */ - public double getDistance(Point m1, Point m2) { - return Nd4j.getExecutioner() - .execAndReturn(ClusterUtils.createDistanceFunctionOp(distanceFunction, m1.getArray(), m2.getArray())) - .getFinalResult().doubleValue(); - } - - /** - * - * @param point - * @return - */ - /*public double getDistanceFromNearestCluster(Point point) { - return nearestCluster(point).getValue(); - }*/ - - - /** - * - * @param clusterId - * @return - */ - public String getClusterCenterId(String clusterId) { - Point clusterCenter = getClusterCenter(clusterId); - return clusterCenter == null ? null : clusterCenter.getId(); - } - - /** - * - * @param clusterId - * @return - */ - public Point getClusterCenter(String clusterId) { - Cluster cluster = getCluster(clusterId); - return cluster == null ? null : cluster.getCenter(); - } - - /** - * - * @param id - * @return - */ - public Cluster getCluster(String id) { - for (int i = 0, j = clusters.size(); i < j; i++) - if (id.equals(clusters.get(i).getId())) - return clusters.get(i); - return null; - } - - /** - * - * @return - */ - public int getClusterCount() { - return getClusters() == null ? 
0 : getClusters().size(); - } - - /** - * Remove all points from all clusters - */ - public void removePoints() { - for (Cluster cluster : getClusters()) - cluster.removePoints(); - } - - /** - * - * @param count the number of clusters to return - * @return the {@code count} clusters with the most points - */ - public List<Cluster> getMostPopulatedClusters(int count) { - List<Cluster> mostPopulated = new ArrayList<>(clusters); - Collections.sort(mostPopulated, new Comparator<Cluster>() { - public int compare(Cluster o1, Cluster o2) { - return Integer.compare(o2.getPoints().size(), o1.getPoints().size()); - } - }); - return mostPopulated.subList(0, count); - } - - /** - * - * @return the empty clusters that were removed - */ - public List<Cluster> removeEmptyClusters() { - List<Cluster> emptyClusters = new ArrayList<>(); - for (Cluster cluster : clusters) - if (cluster.isEmpty()) - emptyClusters.add(cluster); - clusters.removeAll(emptyClusters); - return emptyClusters; - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterUtils.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterUtils.java deleted file mode 100644 index ac1786538..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterUtils.java +++ /dev/null @@ -1,531 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.cluster; - -import lombok.AccessLevel; -import lombok.NoArgsConstructor; -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.apache.commons.lang3.ArrayUtils; -import org.deeplearning4j.clustering.algorithm.Distance; -import org.deeplearning4j.clustering.info.ClusterInfo; -import org.deeplearning4j.clustering.info.ClusterSetInfo; -import org.deeplearning4j.clustering.optimisation.ClusteringOptimizationType; -import org.deeplearning4j.clustering.strategy.OptimisationStrategy; -import org.deeplearning4j.clustering.util.MathUtils; -import org.deeplearning4j.clustering.util.MultiThreadUtils; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.ReduceOp; -import org.nd4j.linalg.api.ops.impl.reduce3.*; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.*; -import java.util.concurrent.ExecutorService; - -@NoArgsConstructor(access = AccessLevel.PRIVATE) -@Slf4j -public class ClusterUtils { - - /** Classify the set of points based on cluster centers. 
This also adds each point to the ClusterSet */ - public static ClusterSetInfo classifyPoints(final ClusterSet clusterSet, List points, - ExecutorService executorService) { - final ClusterSetInfo clusterSetInfo = ClusterSetInfo.initialize(clusterSet, true); - - List tasks = new ArrayList<>(); - for (final Point point : points) { - //tasks.add(new Runnable() { - // public void run() { - try { - PointClassification result = classifyPoint(clusterSet, point); - if (result.isNewLocation()) - clusterSetInfo.getPointLocationChange().incrementAndGet(); - clusterSetInfo.getClusterInfo(result.getCluster().getId()).getPointDistancesFromCenter() - .put(point.getId(), result.getDistanceFromCenter()); - } catch (Throwable t) { - log.warn("Error classifying point", t); - } - // } - } - - //MultiThreadUtils.parallelTasks(tasks, executorService); - return clusterSetInfo; - } - - public static PointClassification classifyPoint(ClusterSet clusterSet, Point point) { - return clusterSet.classifyPoint(point, false); - } - - public static void refreshClustersCenters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo, - ExecutorService executorService) { - List tasks = new ArrayList<>(); - int nClusters = clusterSet.getClusterCount(); - for (int i = 0; i < nClusters; i++) { - final Cluster cluster = clusterSet.getClusters().get(i); - //tasks.add(new Runnable() { - // public void run() { - try { - final ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId()); - refreshClusterCenter(cluster, clusterInfo); - deriveClusterInfoDistanceStatistics(clusterInfo); - } catch (Throwable t) { - log.warn("Error refreshing cluster centers", t); - } - // } - //}); - } - //MultiThreadUtils.parallelTasks(tasks, executorService); - } - - public static void refreshClusterCenter(Cluster cluster, ClusterInfo clusterInfo) { - int pointsCount = cluster.getPoints().size(); - if (pointsCount == 0) - return; - Point center = new Point(Nd4j.create(cluster.getPoints().get(0).getArray().length())); - for (Point point : cluster.getPoints()) { - INDArray arr = point.getArray(); - if (cluster.isInverse()) - center.getArray().subi(arr); - else - center.getArray().addi(arr); - } - center.getArray().divi(pointsCount); - cluster.setCenter(center); - } - - /** - * - * @param info - */ - public static void deriveClusterInfoDistanceStatistics(ClusterInfo info) { - int pointCount = info.getPointDistancesFromCenter().size(); - if (pointCount == 0) - return; - - double[] distances = - ArrayUtils.toPrimitive(info.getPointDistancesFromCenter().values().toArray(new Double[] {})); - double max = info.isInverse() ? 
MathUtils.min(distances) : MathUtils.max(distances); - double total = MathUtils.sum(distances); - info.setMaxPointDistanceFromCenter(max); - info.setTotalPointDistanceFromCenter(total); - info.setAveragePointDistanceFromCenter(total / pointCount); - info.setPointDistanceFromCenterVariance(MathUtils.variance(distances)); - } - - /** - * - * @param clusterSet - * @param points - * @param previousDxs - * @param executorService - * @return - */ - public static INDArray computeSquareDistancesFromNearestCluster(final ClusterSet clusterSet, - final List points, INDArray previousDxs, ExecutorService executorService) { - final int pointsCount = points.size(); - final INDArray dxs = Nd4j.create(pointsCount); - final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1); - - List tasks = new ArrayList<>(); - for (int i = 0; i < pointsCount; i++) { - final int i2 = i; - //tasks.add(new Runnable() { - // public void run() { - try { - Point point = points.get(i2); - double dist = clusterSet.isInverse() ? newCluster.getDistanceToCenter(point) - : Math.pow(newCluster.getDistanceToCenter(point), 2); - dxs.putScalar(i2, /*clusterSet.isInverse() ? dist :*/ dist); - } catch (Throwable t) { - log.warn("Error computing squared distance from nearest cluster", t); - } - // } - //}); - - } - - //MultiThreadUtils.parallelTasks(tasks, executorService); - for (int i = 0; i < pointsCount; i++) { - double previousMinDistance = previousDxs.getDouble(i); - if (clusterSet.isInverse()) { - if (dxs.getDouble(i) < previousMinDistance) { - - dxs.putScalar(i, previousMinDistance); - } - } else if (dxs.getDouble(i) > previousMinDistance) - dxs.putScalar(i, previousMinDistance); - } - - return dxs; - } - - public static INDArray computeWeightedProbaDistancesFromNearestCluster(final ClusterSet clusterSet, - final List points, INDArray previousDxs) { - final int pointsCount = points.size(); - final INDArray dxs = Nd4j.create(pointsCount); - final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1); - - Double sum = new Double(0); - for (int i = 0; i < pointsCount; i++) { - - Point point = points.get(i); - double dist = Math.pow(newCluster.getDistanceToCenter(point), 2); - sum += dist; - dxs.putScalar(i, sum); - } - - return dxs; - } - /** - * - * @param clusterSet - * @return - */ - public static ClusterSetInfo computeClusterSetInfo(ClusterSet clusterSet) { - ExecutorService executor = MultiThreadUtils.newExecutorService(); - ClusterSetInfo info = computeClusterSetInfo(clusterSet, executor); - executor.shutdownNow(); - return info; - } - - public static ClusterSetInfo computeClusterSetInfo(final ClusterSet clusterSet, ExecutorService executorService) { - final ClusterSetInfo info = new ClusterSetInfo(clusterSet.isInverse(), true); - int clusterCount = clusterSet.getClusterCount(); - - List tasks = new ArrayList<>(); - for (int i = 0; i < clusterCount; i++) { - final Cluster cluster = clusterSet.getClusters().get(i); - //tasks.add(new Runnable() { - // public void run() { - try { - info.getClustersInfos().put(cluster.getId(), - computeClusterInfos(cluster, clusterSet.getDistanceFunction())); - } catch (Throwable t) { - log.warn("Error computing cluster set info", t); - } - //} - //}); - } - - - //MultiThreadUtils.parallelTasks(tasks, executorService); - - //tasks = new ArrayList<>(); - for (int i = 0; i < clusterCount; i++) { - final int clusterIdx = i; - final Cluster fromCluster = clusterSet.getClusters().get(i); - //tasks.add(new Runnable() { - //public void 
run() { - try { - for (int k = clusterIdx + 1, l = clusterSet.getClusterCount(); k < l; k++) { - Cluster toCluster = clusterSet.getClusters().get(k); - double distance = Nd4j.getExecutioner() - .execAndReturn(ClusterUtils.createDistanceFunctionOp( - clusterSet.getDistanceFunction(), - fromCluster.getCenter().getArray(), - toCluster.getCenter().getArray())) - .getFinalResult().doubleValue(); - info.getDistancesBetweenClustersCenters().put(fromCluster.getId(), toCluster.getId(), - distance); - } - } catch (Throwable t) { - log.warn("Error computing distances", t); - } - // } - //}); - - } - - //MultiThreadUtils.parallelTasks(tasks, executorService); - - return info; - } - - /** - * - * @param cluster - * @param distanceFunction - * @return - */ - public static ClusterInfo computeClusterInfos(Cluster cluster, Distance distanceFunction) { - ClusterInfo info = new ClusterInfo(cluster.isInverse(), true); - for (int i = 0, j = cluster.getPoints().size(); i < j; i++) { - Point point = cluster.getPoints().get(i); - //shouldn't need to inverse here. other parts of - //the code should interpret the "distance" or score here - double distance = Nd4j.getExecutioner() - .execAndReturn(ClusterUtils.createDistanceFunctionOp(distanceFunction, - cluster.getCenter().getArray(), point.getArray())) - .getFinalResult().doubleValue(); - info.getPointDistancesFromCenter().put(point.getId(), distance); - double diff = info.getTotalPointDistanceFromCenter() + distance; - info.setTotalPointDistanceFromCenter(diff); - } - - if (!cluster.getPoints().isEmpty()) - info.setAveragePointDistanceFromCenter(info.getTotalPointDistanceFromCenter() / cluster.getPoints().size()); - return info; - } - - /** - * - * @param optimization - * @param clusterSet - * @param clusterSetInfo - * @param executor - * @return - */ - public static boolean applyOptimization(OptimisationStrategy optimization, ClusterSet clusterSet, - ClusterSetInfo clusterSetInfo, ExecutorService executor) { - - if (optimization.isClusteringOptimizationType( - ClusteringOptimizationType.MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE)) { - int splitCount = ClusterUtils.splitClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet, - clusterSetInfo, optimization.getClusteringOptimizationValue(), executor); - return splitCount > 0; - } - - if (optimization.isClusteringOptimizationType( - ClusteringOptimizationType.MINIMIZE_MAXIMUM_POINT_TO_CENTER_DISTANCE)) { - int splitCount = ClusterUtils.splitClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet, - clusterSetInfo, optimization.getClusteringOptimizationValue(), executor); - return splitCount > 0; - } - - return false; - } - - /** - * - * @param clusterSet - * @param info - * @param count - * @return - */ - public static List getMostSpreadOutClusters(final ClusterSet clusterSet, final ClusterSetInfo info, - int count) { - List clusters = new ArrayList<>(clusterSet.getClusters()); - Collections.sort(clusters, new Comparator() { - public int compare(Cluster o1, Cluster o2) { - Double o1TotalDistance = info.getClusterInfo(o1.getId()).getTotalPointDistanceFromCenter(); - Double o2TotalDistance = info.getClusterInfo(o2.getId()).getTotalPointDistanceFromCenter(); - int comp = o1TotalDistance.compareTo(o2TotalDistance); - return !clusterSet.getClusters().get(0).isInverse() ? 
-comp : comp; - } - }); - - return clusters.subList(0, count); - } - - /** - * - * @param clusterSet - * @param info - * @param maximumAverageDistance - * @return - */ - public static List getClustersWhereAverageDistanceFromCenterGreaterThan(final ClusterSet clusterSet, - final ClusterSetInfo info, double maximumAverageDistance) { - List clusters = new ArrayList<>(); - for (Cluster cluster : clusterSet.getClusters()) { - ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId()); - if (clusterInfo != null) { - //distances - if (clusterInfo.isInverse()) { - if (clusterInfo.getAveragePointDistanceFromCenter() < maximumAverageDistance) - clusters.add(cluster); - } else { - if (clusterInfo.getAveragePointDistanceFromCenter() > maximumAverageDistance) - clusters.add(cluster); - } - - } - - } - return clusters; - } - - /** - * - * @param clusterSet - * @param info - * @param maximumDistance - * @return - */ - public static List getClustersWhereMaximumDistanceFromCenterGreaterThan(final ClusterSet clusterSet, - final ClusterSetInfo info, double maximumDistance) { - List clusters = new ArrayList<>(); - for (Cluster cluster : clusterSet.getClusters()) { - ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId()); - if (clusterInfo != null) { - if (clusterInfo.isInverse() && clusterInfo.getMaxPointDistanceFromCenter() < maximumDistance) { - clusters.add(cluster); - } else if (clusterInfo.getMaxPointDistanceFromCenter() > maximumDistance) { - clusters.add(cluster); - - } - } - } - return clusters; - } - - /** - * - * @param clusterSet - * @param clusterSetInfo - * @param count - * @param executorService - * @return - */ - public static int splitMostSpreadOutClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count, - ExecutorService executorService) { - List clustersToSplit = getMostSpreadOutClusters(clusterSet, clusterSetInfo, count); - splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService); - return clustersToSplit.size(); - } - - /** - * - * @param clusterSet - * @param clusterSetInfo - * @param maxWithinClusterDistance - * @param executorService - * @return - */ - public static int splitClustersWhereAverageDistanceFromCenterGreaterThan(ClusterSet clusterSet, - ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) { - List clustersToSplit = getClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo, - maxWithinClusterDistance); - splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService); - return clustersToSplit.size(); - } - - /** - * - * @param clusterSet - * @param clusterSetInfo - * @param maxWithinClusterDistance - * @param executorService - * @return - */ - public static int splitClustersWhereMaximumDistanceFromCenterGreaterThan(ClusterSet clusterSet, - ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) { - List clustersToSplit = getClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo, - maxWithinClusterDistance); - splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService); - return clustersToSplit.size(); - } - - /** - * - * @param clusterSet - * @param clusterSetInfo - * @param count - * @param executorService - */ - public static void splitMostPopulatedClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count, - ExecutorService executorService) { - List clustersToSplit = 
clusterSet.getMostPopulatedClusters(count); - splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService); - } - - /** - * - * @param clusterSet - * @param clusterSetInfo - * @param clusters - * @param maxDistance - * @param executorService - */ - public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo, - List clusters, final double maxDistance, ExecutorService executorService) { - final Random random = new Random(); - List tasks = new ArrayList<>(); - for (final Cluster cluster : clusters) { - tasks.add(new Runnable() { - public void run() { - try { - ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId()); - List fartherPoints = clusterInfo.getPointsFartherFromCenterThan(maxDistance); - int rank = Math.min(fartherPoints.size(), 3); - String pointId = fartherPoints.get(random.nextInt(rank)); - Point point = cluster.removePoint(pointId); - clusterSet.addNewClusterWithCenter(point); - } catch (Throwable t) { - log.warn("Error splitting clusters", t); - } - } - }); - } - MultiThreadUtils.parallelTasks(tasks, executorService); - } - - /** - * - * @param clusterSet - * @param clusterSetInfo - * @param clusters - * @param executorService - */ - public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo, - List clusters, ExecutorService executorService) { - final Random random = new Random(); - List tasks = new ArrayList<>(); - for (final Cluster cluster : clusters) { - tasks.add(new Runnable() { - public void run() { - try { - Point point = cluster.getPoints().remove(random.nextInt(cluster.getPoints().size())); - clusterSet.addNewClusterWithCenter(point); - } catch (Throwable t) { - log.warn("Error Splitting clusters (2)", t); - } - } - }); - } - - MultiThreadUtils.parallelTasks(tasks, executorService); - } - - public static ReduceOp createDistanceFunctionOp(Distance distanceFunction, INDArray x, INDArray y, int...dimensions){ - val op = createDistanceFunctionOp(distanceFunction, x, y); - op.setDimensions(dimensions); - return op; - } - - public static ReduceOp createDistanceFunctionOp(Distance distanceFunction, INDArray x, INDArray y){ - switch (distanceFunction){ - case COSINE_DISTANCE: - return new CosineDistance(x,y); - case COSINE_SIMILARITY: - return new CosineSimilarity(x,y); - case DOT: - return new Dot(x,y); - case EUCLIDEAN: - return new EuclideanDistance(x,y); - case JACCARD: - return new JaccardDistance(x,y); - case MANHATTAN: - return new ManhattanDistance(x,y); - default: - throw new IllegalStateException("Unknown distance function: " + distanceFunction); - } - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/Point.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/Point.java deleted file mode 100644 index 14147b004..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/Point.java +++ /dev/null @@ -1,107 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. 
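// Editorial example (not part of the original diff): computing a distance with the
// createDistanceFunctionOp(...) factory above, using the same execAndReturn pattern as
// Cluster.getDistanceToCenter earlier in this diff. Note that the switch in the factory
// has no case for Distance.HAMMING even though the enum declares it, so HAMMING falls
// through to the IllegalStateException branch.
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

class DistanceOpExample {
    static double demo() {
        INDArray x = Nd4j.create(new double[] {1, 0, 0});
        INDArray y = Nd4j.create(new double[] {0, 1, 0});
        double d = Nd4j.getExecutioner()
                .execAndReturn(ClusterUtils.createDistanceFunctionOp(Distance.EUCLIDEAN, x, y))
                .getFinalResult().doubleValue();
        return d; // sqrt(2) for these vectors
    }
}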
- * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.cluster; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.NoArgsConstructor; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; -import java.util.UUID; - -/** - * - */ -@Data -@NoArgsConstructor(access = AccessLevel.PROTECTED) -public class Point implements Serializable { - - private static final long serialVersionUID = -6658028541426027226L; - - private String id = UUID.randomUUID().toString(); - private String label; - private INDArray array; - - - /** - * - * @param array - */ - public Point(INDArray array) { - super(); - this.array = array; - } - - /** - * - * @param id - * @param array - */ - public Point(String id, INDArray array) { - super(); - this.id = id; - this.array = array; - } - - public Point(String id, String label, double[] data) { - this(id, label, Nd4j.create(data)); - } - - public Point(String id, String label, INDArray array) { - super(); - this.id = id; - this.label = label; - this.array = array; - } - - - /** - * - * @param matrix - * @return - */ - public static List toPoints(INDArray matrix) { - List arr = new ArrayList<>(matrix.rows()); - for (int i = 0; i < matrix.rows(); i++) { - arr.add(new Point(matrix.getRow(i))); - } - - return arr; - } - - /** - * - * @param vectors - * @return - */ - public static List toPoints(List vectors) { - List points = new ArrayList<>(); - for (INDArray vector : vectors) - points.add(new Point(vector)); - return points; - } - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/PointClassification.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/PointClassification.java deleted file mode 100644 index 6951b4a03..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/PointClassification.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
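Point, deleted above, is a thin Serializable wrapper that pairs an INDArray row vector with an id and optional label; the toPoints helpers are the usual entry point. A short sketch, assuming Nd4j plus the deleted Point class:

    import java.util.List;
    import org.deeplearning4j.clustering.cluster.Point;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class PointExample {
        public static void main(String[] args) {
            INDArray data = Nd4j.rand(100, 5);         // 100 samples, 5 features
            List<Point> points = Point.toPoints(data); // one Point per row, ids auto-generated as UUIDs
            System.out.println(points.size() + " points, dim = " + points.get(0).getArray().length());
        }
    }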
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.cluster; - -import lombok.AccessLevel; -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.io.Serializable; - -@Data -@NoArgsConstructor(access = AccessLevel.PROTECTED) -@AllArgsConstructor -public class PointClassification implements Serializable { - - private Cluster cluster; - private double distanceFromCenter; - private boolean newLocation; - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/ClusteringAlgorithmCondition.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/ClusteringAlgorithmCondition.java deleted file mode 100644 index 852a58920..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/ClusteringAlgorithmCondition.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.condition; - -import org.deeplearning4j.clustering.iteration.IterationHistory; - -/** - * - */ -public interface ClusteringAlgorithmCondition { - - /** - * - * @param iterationHistory - * @return - */ - boolean isSatisfied(IterationHistory iterationHistory); - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/ConvergenceCondition.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/ConvergenceCondition.java deleted file mode 100644 index 6c2659f60..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/ConvergenceCondition.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. 
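ClusteringAlgorithmCondition, deleted above, is the single-method contract that all termination conditions in this package implement against the iteration history. A hypothetical custom condition, not part of the deleted code, that stops on a wall-clock budget:

    import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition;
    import org.deeplearning4j.clustering.iteration.IterationHistory;

    // Hypothetical: terminate clustering once a wall-clock budget is exhausted,
    // regardless of how many iterations have run.
    public class TimeBudgetCondition implements ClusteringAlgorithmCondition {
        private final long deadlineMillis;

        public TimeBudgetCondition(long budgetMillis) {
            this.deadlineMillis = System.currentTimeMillis() + budgetMillis;
        }

        @Override
        public boolean isSatisfied(IterationHistory iterationHistory) {
            return System.currentTimeMillis() >= deadlineMillis;
        }
    }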
- * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.condition; - -import lombok.AccessLevel; -import lombok.AllArgsConstructor; -import lombok.NoArgsConstructor; -import org.deeplearning4j.clustering.iteration.IterationHistory; -import org.nd4j.linalg.indexing.conditions.Condition; -import org.nd4j.linalg.indexing.conditions.LessThan; - -import java.io.Serializable; - -@NoArgsConstructor(access = AccessLevel.PROTECTED) -@AllArgsConstructor(access = AccessLevel.PROTECTED) -public class ConvergenceCondition implements ClusteringAlgorithmCondition, Serializable { - - private Condition convergenceCondition; - private double pointsDistributionChangeRate; - - - /** - * - * @param pointsDistributionChangeRate - * @return - */ - public static ConvergenceCondition distributionVariationRateLessThan(double pointsDistributionChangeRate) { - Condition condition = new LessThan(pointsDistributionChangeRate); - return new ConvergenceCondition(condition, pointsDistributionChangeRate); - } - - - /** - * - * @param iterationHistory - * @return - */ - public boolean isSatisfied(IterationHistory iterationHistory) { - int iterationCount = iterationHistory.getIterationCount(); - if (iterationCount <= 1) - return false; - - double variation = iterationHistory.getMostRecentClusterSetInfo().getPointLocationChange().get(); - variation /= iterationHistory.getMostRecentClusterSetInfo().getPointsCount(); - - return convergenceCondition.apply(variation); - } - - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/FixedIterationCountCondition.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/FixedIterationCountCondition.java deleted file mode 100644 index 7eda7a7ec..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/FixedIterationCountCondition.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
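ConvergenceCondition, just deleted above, declares the run converged when the fraction of points that changed cluster during the last iteration drops below the configured rate. The same test in isolation, as a sketch with hypothetical variable names:

    // Satisfied when fewer than `threshold` (e.g. 0.01 == 1%) of all points
    // switched clusters during the most recent iteration.
    static boolean hasConverged(int pointsThatChangedCluster, int totalPoints, double threshold) {
        double variation = (double) pointsThatChangedCluster / totalPoints;
        return variation < threshold;
    }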
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.condition; - -import lombok.AccessLevel; -import lombok.NoArgsConstructor; -import org.deeplearning4j.clustering.iteration.IterationHistory; -import org.nd4j.linalg.indexing.conditions.Condition; -import org.nd4j.linalg.indexing.conditions.GreaterThanOrEqual; - -import java.io.Serializable; - -/** - * - */ -@NoArgsConstructor(access = AccessLevel.PROTECTED) -public class FixedIterationCountCondition implements ClusteringAlgorithmCondition, Serializable { - - private Condition iterationCountCondition; - - protected FixedIterationCountCondition(int initialClusterCount) { - iterationCountCondition = new GreaterThanOrEqual(initialClusterCount); - } - - /** - * - * @param iterationCount - * @return - */ - public static FixedIterationCountCondition iterationCountGreaterThan(int iterationCount) { - return new FixedIterationCountCondition(iterationCount); - } - - /** - * - * @param iterationHistory - * @return - */ - public boolean isSatisfied(IterationHistory iterationHistory) { - return iterationCountCondition.apply(iterationHistory == null ? 0 : iterationHistory.getIterationCount()); - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/VarianceVariationCondition.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/VarianceVariationCondition.java deleted file mode 100644 index ff91dd7eb..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/VarianceVariationCondition.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.condition; - -import lombok.AccessLevel; -import lombok.AllArgsConstructor; -import lombok.NoArgsConstructor; -import org.deeplearning4j.clustering.iteration.IterationHistory; -import org.nd4j.linalg.indexing.conditions.Condition; -import org.nd4j.linalg.indexing.conditions.LessThan; - -import java.io.Serializable; - -/** - * - */ -@NoArgsConstructor(access = AccessLevel.PROTECTED) -@AllArgsConstructor -public class VarianceVariationCondition implements ClusteringAlgorithmCondition, Serializable { - - private Condition varianceVariationCondition; - private int period; - - - - /** - * - * @param varianceVariation - * @param period - * @return - */ - public static VarianceVariationCondition varianceVariationLessThan(double varianceVariation, int period) { - Condition condition = new LessThan(varianceVariation); - return new VarianceVariationCondition(condition, period); - } - - - /** - * - * @param iterationHistory - * @return - */ - public boolean isSatisfied(IterationHistory iterationHistory) { - if (iterationHistory.getIterationCount() <= period) - return false; - - for (int i = 0, j = iterationHistory.getIterationCount(); i < period; i++) { - double variation = iterationHistory.getIterationInfo(j - i).getClusterSetInfo() - .getPointDistanceFromClusterVariance(); - variation -= iterationHistory.getIterationInfo(j - i - 1).getClusterSetInfo() - .getPointDistanceFromClusterVariance(); - variation /= iterationHistory.getIterationInfo(j - i - 1).getClusterSetInfo() - .getPointDistanceFromClusterVariance(); - - if (!varianceVariationCondition.apply(variation)) - return false; - } - - return true; - } - - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterInfo.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterInfo.java deleted file mode 100644 index 2b78ee3e8..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterInfo.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
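VarianceVariationCondition, deleted above, looks back over a window of `period` iterations and is satisfied only if the relative change of the within-cluster distance variance stayed below the bound at every step of the window. The same check in isolation, as a sketch over a plain array of per-iteration variances:

    // variances[i] = within-cluster distance variance after iteration i.
    static boolean varianceStable(double[] variances, double maxRelativeChange, int period) {
        int last = variances.length - 1;
        if (last < period)
            return false; // not enough history yet
        for (int i = 0; i < period; i++) {
            double curr = variances[last - i];
            double prev = variances[last - i - 1];
            double relativeChange = (curr - prev) / prev; // signed, mirrors the LessThan condition above
            if (!(relativeChange < maxRelativeChange))
                return false;
        }
        return true;
    }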
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.info;
-
-import lombok.Data;
-
-import java.io.Serializable;
-import java.util.*;
-import java.util.concurrent.ConcurrentHashMap;
-
-/**
- *
- */
-@Data
-public class ClusterInfo implements Serializable {
-
-    private double averagePointDistanceFromCenter;
-    private double maxPointDistanceFromCenter;
-    private double pointDistanceFromCenterVariance;
-    private double totalPointDistanceFromCenter;
-    private boolean inverse;
-    private Map<String, Double> pointDistancesFromCenter = new ConcurrentHashMap<>();
-
-    public ClusterInfo(boolean inverse) {
-        this(false, inverse);
-    }
-
-    /**
-     *
-     * @param threadSafe
-     */
-    public ClusterInfo(boolean threadSafe, boolean inverse) {
-        super();
-        this.inverse = inverse;
-        if (threadSafe) {
-            pointDistancesFromCenter = Collections.synchronizedMap(pointDistancesFromCenter);
-        }
-    }
-
-    /**
-     *
-     * @return
-     */
-    public Set<Map.Entry<String, Double>> getSortedPointDistancesFromCenter() {
-        SortedSet<Map.Entry<String, Double>> sortedEntries = new TreeSet<>(new Comparator<Map.Entry<String, Double>>() {
-            @Override
-            public int compare(Map.Entry<String, Double> e1, Map.Entry<String, Double> e2) {
-                int res = e1.getValue().compareTo(e2.getValue());
-                return res != 0 ? res : 1;
-            }
-        });
-        sortedEntries.addAll(pointDistancesFromCenter.entrySet());
-        return sortedEntries;
-    }
-
-    /**
-     *
-     * @return
-     */
-    public Set<Map.Entry<String, Double>> getReverseSortedPointDistancesFromCenter() {
-        SortedSet<Map.Entry<String, Double>> sortedEntries = new TreeSet<>(new Comparator<Map.Entry<String, Double>>() {
-            @Override
-            public int compare(Map.Entry<String, Double> e1, Map.Entry<String, Double> e2) {
-                int res = e1.getValue().compareTo(e2.getValue());
-                return -(res != 0 ? res : 1);
-            }
-        });
-        sortedEntries.addAll(pointDistancesFromCenter.entrySet());
-        return sortedEntries;
-    }
-
-    /**
-     *
-     * @param maxDistance
-     * @return
-     */
-    public List<String> getPointsFartherFromCenterThan(double maxDistance) {
-        Set<Map.Entry<String, Double>> sorted = getReverseSortedPointDistancesFromCenter();
-        List<String> ids = new ArrayList<>();
-        for (Map.Entry<String, Double> entry : sorted) {
-            if (inverse && entry.getValue() < -maxDistance) {
-                if (entry.getValue() < -maxDistance)
-                    break;
-            }
-
-            else if (entry.getValue() > maxDistance)
-                break;
-
-            ids.add(entry.getKey());
-        }
-        return ids;
-    }
-
-
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java
deleted file mode 100644
index 3ddfd1b25..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java
+++ /dev/null
@@ -1,142 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
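Note that the comparators in the deleted ClusterInfo never return 0: ties on distance fall back to a constant 1 (negated for the reverse order). The reason is that a TreeSet treats compare() == 0 as "duplicate" and would silently drop one of two points sitting at exactly the same distance from the center. A standalone demonstration of the trick:

    import java.util.AbstractMap;
    import java.util.Map;
    import java.util.TreeSet;

    public class TieBreakDemo {
        public static void main(String[] args) {
            TreeSet<Map.Entry<String, Double>> sorted = new TreeSet<>((e1, e2) -> {
                int res = e1.getValue().compareTo(e2.getValue());
                return res != 0 ? res : 1; // ties ordered arbitrarily, never "equal"
            });
            sorted.add(new AbstractMap.SimpleEntry<>("pointA", 2.0));
            sorted.add(new AbstractMap.SimpleEntry<>("pointB", 2.0)); // same distance as pointA
            System.out.println(sorted.size()); // 2 -- a comparator returning 0 on ties would keep only 1
        }
    }

The trade-off is that contains() and remove() become unreliable on such a set, which is harmless here because the set is only built once and iterated.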
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.info;
-
-import org.nd4j.shade.guava.collect.HashBasedTable;
-import org.nd4j.shade.guava.collect.Table;
-import org.deeplearning4j.clustering.cluster.Cluster;
-import org.deeplearning4j.clustering.cluster.ClusterSet;
-
-import java.io.Serializable;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.atomic.AtomicInteger;
-
-public class ClusterSetInfo implements Serializable {
-
-    private Map<String, ClusterInfo> clustersInfos = new HashMap<>();
-    private Table<String, String, Double> distancesBetweenClustersCenters = HashBasedTable.create();
-    private AtomicInteger pointLocationChange;
-    private boolean threadSafe;
-    private boolean inverse;
-
-    public ClusterSetInfo(boolean inverse) {
-        this(inverse, false);
-    }
-
-    /**
-     *
-     * @param inverse
-     * @param threadSafe
-     */
-    public ClusterSetInfo(boolean inverse, boolean threadSafe) {
-        this.pointLocationChange = new AtomicInteger(0);
-        this.threadSafe = threadSafe;
-        this.inverse = inverse;
-        if (threadSafe) {
-            clustersInfos = Collections.synchronizedMap(clustersInfos);
-        }
-    }
-
-
-    /**
-     *
-     * @param clusterSet
-     * @param threadSafe
-     * @return
-     */
-    public static ClusterSetInfo initialize(ClusterSet clusterSet, boolean threadSafe) {
-        ClusterSetInfo info = new ClusterSetInfo(clusterSet.isInverse(), threadSafe);
-        for (int i = 0, j = clusterSet.getClusterCount(); i < j; i++)
-            info.addClusterInfo(clusterSet.getClusters().get(i).getId());
-        return info;
-    }
-
-    public void removeClusterInfos(List<Cluster> clusters) {
-        for (Cluster cluster : clusters) {
-            clustersInfos.remove(cluster.getId());
-        }
-    }
-
-    public ClusterInfo addClusterInfo(String clusterId) {
-        ClusterInfo clusterInfo = new ClusterInfo(this.threadSafe);
-        clustersInfos.put(clusterId, clusterInfo);
-        return clusterInfo;
-    }
-
-    public ClusterInfo getClusterInfo(String clusterId) {
-        return clustersInfos.get(clusterId);
-    }
-
-    public double getAveragePointDistanceFromClusterCenter() {
-        if (clustersInfos == null || clustersInfos.isEmpty())
-            return 0;
-
-        double average = 0;
-        for (ClusterInfo info : clustersInfos.values())
-            average += info.getAveragePointDistanceFromCenter();
-        return average / clustersInfos.size();
-    }
-
-    public double getPointDistanceFromClusterVariance() {
-        if (clustersInfos == null || clustersInfos.isEmpty())
-            return 0;
-
-        double average = 0;
-        for (ClusterInfo info : clustersInfos.values())
-            average += info.getPointDistanceFromCenterVariance();
-        return average / clustersInfos.size();
-    }
-
-    public int getPointsCount() {
-        int count = 0;
-        for (ClusterInfo clusterInfo : clustersInfos.values())
-            count += clusterInfo.getPointDistancesFromCenter().size();
-        return count;
-    }
-
-    public Map<String, ClusterInfo> getClustersInfos() {
-        return clustersInfos;
-    }
-
-    public void setClustersInfos(Map<String, ClusterInfo> clustersInfos) {
-        this.clustersInfos = clustersInfos;
-    }
-
-    public Table<String, String, Double> getDistancesBetweenClustersCenters() {
-        return distancesBetweenClustersCenters;
-    }
-
-    public void setDistancesBetweenClustersCenters(Table<String, String, Double> interClusterDistances) {
-        this.distancesBetweenClustersCenters = interClusterDistances;
-    }
-
-    public AtomicInteger getPointLocationChange() {
-        return pointLocationChange;
-    }
-
-    public void setPointLocationChange(AtomicInteger pointLocationChange) {
-        this.pointLocationChange = pointLocationChange;
-    }
-
-}
diff --git
a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/iteration/IterationHistory.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/iteration/IterationHistory.java deleted file mode 100644 index 0854e5eb1..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/iteration/IterationHistory.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.iteration; - -import lombok.Getter; -import lombok.Setter; -import org.deeplearning4j.clustering.info.ClusterSetInfo; - -import java.io.Serializable; -import java.util.HashMap; -import java.util.Map; - -public class IterationHistory implements Serializable { - @Getter - @Setter - private Map iterationsInfos = new HashMap<>(); - - /** - * - * @return - */ - public ClusterSetInfo getMostRecentClusterSetInfo() { - IterationInfo iterationInfo = getMostRecentIterationInfo(); - return iterationInfo == null ? null : iterationInfo.getClusterSetInfo(); - } - - /** - * - * @return - */ - public IterationInfo getMostRecentIterationInfo() { - return getIterationInfo(getIterationCount() - 1); - } - - /** - * - * @return - */ - public int getIterationCount() { - return getIterationsInfos().size(); - } - - /** - * - * @param iterationIdx - * @return - */ - public IterationInfo getIterationInfo(int iterationIdx) { - return getIterationsInfos().get(iterationIdx); - } - - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/iteration/IterationInfo.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/iteration/IterationInfo.java deleted file mode 100644 index 0036f3c47..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/iteration/IterationInfo.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. 
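IterationHistory, deleted above, keys each IterationInfo by its iteration index, so the latest snapshot is index count - 1 and getMostRecentClusterSetInfo() returns null until the first iteration lands. A small sketch of reading the latest statistics, assuming the deleted classes:

    import org.deeplearning4j.clustering.info.ClusterSetInfo;
    import org.deeplearning4j.clustering.iteration.IterationHistory;

    static double latestAverageDistanceFromCenter(IterationHistory history) {
        ClusterSetInfo latest = history.getMostRecentClusterSetInfo();
        // Null until at least one iteration has been recorded.
        return latest == null ? Double.NaN : latest.getAveragePointDistanceFromClusterCenter();
    }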
- * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.iteration; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.NoArgsConstructor; -import org.deeplearning4j.clustering.info.ClusterSetInfo; - -import java.io.Serializable; - -@Data -@NoArgsConstructor(access = AccessLevel.PROTECTED) -public class IterationInfo implements Serializable { - - private int index; - private ClusterSetInfo clusterSetInfo; - private boolean strategyApplied; - - public IterationInfo(int index) { - super(); - this.index = index; - } - - public IterationInfo(int index, ClusterSetInfo clusterSetInfo) { - super(); - this.index = index; - this.clusterSetInfo = clusterSetInfo; - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java deleted file mode 100644 index c3e0bc418..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java +++ /dev/null @@ -1,142 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.kdtree; - -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.custom.KnnMinDistance; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.Pair; - -import java.io.Serializable; - -public class HyperRect implements Serializable { - - //private List points; - private float[] lowerEnds; - private float[] higherEnds; - private INDArray lowerEndsIND; - private INDArray higherEndsIND; - - public HyperRect(float[] lowerEndsIn, float[] higherEndsIn) { - this.lowerEnds = new float[lowerEndsIn.length]; - this.higherEnds = new float[lowerEndsIn.length]; - System.arraycopy(lowerEndsIn, 0 , this.lowerEnds, 0, lowerEndsIn.length); - System.arraycopy(higherEndsIn, 0 , this.higherEnds, 0, higherEndsIn.length); - lowerEndsIND = Nd4j.createFromArray(lowerEnds); - higherEndsIND = Nd4j.createFromArray(higherEnds); - } - - public HyperRect(float[] point) { - this(point, point); - } - - public HyperRect(Pair ends) { - this(ends.getFirst(), ends.getSecond()); - } - - - public void enlargeTo(INDArray point) { - float[] pointAsArray = point.toFloatVector(); - for (int i = 0; i < lowerEnds.length; i++) { - float p = pointAsArray[i]; - if (lowerEnds[i] > p) - lowerEnds[i] = p; - else if (higherEnds[i] < p) - higherEnds[i] = p; - } - } - - public static Pair point(INDArray vector) { - Pair ret = new Pair<>(); - float[] curr = new float[(int)vector.length()]; - for (int i = 0; i < vector.length(); i++) { - curr[i] = vector.getFloat(i); - } - ret.setFirst(curr); - ret.setSecond(curr); - return ret; - } - - - /*public List contains(INDArray hPoint) { - List ret = new ArrayList<>(); - for (int i = 0; i < hPoint.length(); i++) { - ret.add(lowerEnds[i] <= hPoint.getDouble(i) && - higherEnds[i] >= hPoint.getDouble(i)); - } - return ret; - }*/ - - public double minDistance(INDArray hPoint, INDArray output) { - Nd4j.exec(new KnnMinDistance(hPoint, lowerEndsIND, higherEndsIND, output)); - return output.getFloat(0); - - /*double ret = 0.0; - double[] pointAsArray = hPoint.toDoubleVector(); - for (int i = 0; i < pointAsArray.length; i++) { - double p = pointAsArray[i]; - if (!(lowerEnds[i] <= p || higherEnds[i] <= p)) { - if (p < lowerEnds[i]) - ret += Math.pow((p - lowerEnds[i]), 2); - else - ret += Math.pow((p - higherEnds[i]), 2); - } - } - ret = Math.pow(ret, 0.5); - return ret;*/ - } - - public HyperRect getUpper(INDArray hPoint, int desc) { - //Interval interval = points.get(desc); - float higher = higherEnds[desc]; - float d = hPoint.getFloat(desc); - if (higher < d) - return null; - HyperRect ret = new HyperRect(lowerEnds,higherEnds); - if (ret.lowerEnds[desc] < d) - ret.lowerEnds[desc] = d; - return ret; - } - - public HyperRect getLower(INDArray hPoint, int desc) { - //Interval interval = points.get(desc); - float lower = lowerEnds[desc]; - float d = hPoint.getFloat(desc); - if (lower > d) - return null; - HyperRect ret = new HyperRect(lowerEnds,higherEnds); - //Interval i2 = ret.points.get(desc); - if (ret.higherEnds[desc] > d) - ret.higherEnds[desc] = d; - return ret; - } - - @Override - public String toString() { - String retVal = ""; - retVal += "["; - for (int i = 0; i < lowerEnds.length; ++i) { - retVal += "(" + lowerEnds[i] + " - " + higherEnds[i] + ") "; - } - retVal += "]"; - return retVal; - } -} diff --git 
a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java deleted file mode 100644 index fd77c8342..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java +++ /dev/null @@ -1,370 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.kdtree; - -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.reduce.bool.Any; -import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.Pair; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; -import java.util.List; - -public class KDTree implements Serializable { - - private KDNode root; - private int dims = 100; - public final static int GREATER = 1; - public final static int LESS = 0; - private int size = 0; - private HyperRect rect; - - public KDTree(int dims) { - this.dims = dims; - } - - /** - * Insert a point in to the tree - * @param point the point to insert - */ - public void insert(INDArray point) { - if (!point.isVector() || point.length() != dims) - throw new IllegalArgumentException("Point must be a vector of length " + dims); - - if (root == null) { - root = new KDNode(point); - rect = new HyperRect(/*HyperRect.point(point)*/ point.toFloatVector()); - } else { - int disc = 0; - KDNode node = root; - KDNode insert = new KDNode(point); - int successor; - while (true) { - //exactly equal - INDArray pt = node.getPoint(); - INDArray countEq = Nd4j.getExecutioner().execAndReturn(new Any(pt.neq(point))).z(); - if (countEq.getInt(0) == 0) { - return; - } else { - successor = successor(node, point, disc); - KDNode child; - if (successor < 1) - child = node.getLeft(); - else - child = node.getRight(); - if (child == null) - break; - disc = (disc + 1) % dims; - node = child; - } - } - - if (successor < 1) - node.setLeft(insert); - - else - node.setRight(insert); - - rect.enlargeTo(point); - insert.setParent(node); - } - size++; - - } - - - public INDArray delete(INDArray point) { - KDNode node = root; - int _disc = 0; - while (node != null) { - if (node.point == point) - break; - int successor = successor(node, point, _disc); - if (successor < 1) - node = node.getLeft(); - else - node = node.getRight(); - _disc = (_disc + 1) % dims; - } - - if (node != null) { - if (node == root) { - root = 
delete(root, _disc);
-            } else
-                node = delete(node, _disc);
-            size--;
-            if (size == 1) {
-                rect = new HyperRect(HyperRect.point(point));
-            } else if (size == 0)
-                rect = null;
-
-        }
-        return node.getPoint();
-    }
-
-    // Share this data for recursive calls of "knn"
-    private float currentDistance;
-    private INDArray currentPoint;
-    private INDArray minDistance = Nd4j.scalar(0.f);
-
-
-    public List<Pair<Float, INDArray>> knn(INDArray point, float distance) {
-        List<Pair<Float, INDArray>> best = new ArrayList<>();
-        currentDistance = distance;
-        currentPoint = point;
-        knn(root, rect, best, 0);
-        Collections.sort(best, new Comparator<Pair<Float, INDArray>>() {
-            @Override
-            public int compare(Pair<Float, INDArray> o1, Pair<Float, INDArray> o2) {
-                return Float.compare(o1.getKey(), o2.getKey());
-            }
-        });
-
-        return best;
-    }
-
-
-    private void knn(KDNode node, HyperRect rect, List<Pair<Float, INDArray>> best, int _disc) {
-        if (node == null || rect == null || rect.minDistance(currentPoint, minDistance) > currentDistance)
-            return;
-        int _discNext = (_disc + 1) % dims;
-        float distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(currentPoint, node.point, minDistance)).getFinalResult()
-                        .floatValue();
-
-        if (distance <= currentDistance) {
-            best.add(Pair.of(distance, node.getPoint()));
-        }
-
-        HyperRect lower = rect.getLower(node.point, _disc);
-        HyperRect upper = rect.getUpper(node.point, _disc);
-        knn(node.getLeft(), lower, best, _discNext);
-        knn(node.getRight(), upper, best, _discNext);
-    }
-
-    /**
-     * Query for nearest neighbor. Returns the distance and point
-     * @param point the point to query for
-     * @return
-     */
-    public Pair<Double, INDArray> nn(INDArray point) {
-        return nn(root, point, rect, Double.POSITIVE_INFINITY, null, 0);
-    }
-
-
-    private Pair<Double, INDArray> nn(KDNode node, INDArray point, HyperRect rect, double dist, INDArray best,
-                    int _disc) {
-        if (node == null || rect.minDistance(point, minDistance) > dist)
-            return Pair.of(Double.POSITIVE_INFINITY, null);
-
-        int _discNext = (_disc + 1) % dims;
-        double dist2 = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(point, Nd4j.zeros(point.dataType(), point.shape()))).getFinalResult().doubleValue();
-        if (dist2 < dist) {
-            best = node.getPoint();
-            dist = dist2;
-        }
-
-        HyperRect lower = rect.getLower(node.point, _disc);
-        HyperRect upper = rect.getUpper(node.point, _disc);
-
-        if (point.getDouble(_disc) < node.point.getDouble(_disc)) {
-            Pair<Double, INDArray> left = nn(node.getLeft(), point, lower, dist, best, _discNext);
-            Pair<Double, INDArray> right = nn(node.getRight(), point, upper, dist, best, _discNext);
-            if (left.getKey() < dist)
-                return left;
-            else if (right.getKey() < dist)
-                return right;
-
-        } else {
-            Pair<Double, INDArray> left = nn(node.getRight(), point, upper, dist, best, _discNext);
-            Pair<Double, INDArray> right = nn(node.getLeft(), point, lower, dist, best, _discNext);
-            if (left.getKey() < dist)
-                return left;
-            else if (right.getKey() < dist)
-                return right;
-        }
-
-        return Pair.of(dist, best);
-
-    }
-
-    private KDNode delete(KDNode delete, int _disc) {
-        if (delete.getLeft() != null && delete.getRight() != null) {
-            if (delete.getParent() != null) {
-                if (delete.getParent().getLeft() == delete)
-                    delete.getParent().setLeft(null);
-                else
-                    delete.getParent().setRight(null);
-
-            }
-            return null;
-        }
-
-        int disc = _disc;
-        _disc = (_disc + 1) % dims;
-        Pair<KDNode, Integer> qd = null;
-        if (delete.getRight() != null) {
-            qd = min(delete.getRight(), disc, _disc);
-        } else if (delete.getLeft() != null)
-            qd = max(delete.getLeft(), disc, _disc);
-        if (qd == null) {// is leaf
-            return null;
-        }
-        delete.point = qd.getKey().point;
-        KDNode qFather = qd.getKey().getParent();
-        if (qFather.getLeft() == qd.getKey()) {
qFather.setLeft(delete(qd.getKey(), disc));
-        } else if (qFather.getRight() == qd.getKey()) {
-            qFather.setRight(delete(qd.getKey(), disc));
-
-        }
-
-        return delete;
-
-
-    }
-
-
-    private Pair<KDNode, Integer> max(KDNode node, int disc, int _disc) {
-        int discNext = (_disc + 1) % dims;
-        if (_disc == disc) {
-            KDNode child = node.getLeft();
-            if (child != null) {
-                return max(child, disc, discNext);
-            }
-        } else if (node.getLeft() != null || node.getRight() != null) {
-            Pair<KDNode, Integer> left = null, right = null;
-            if (node.getLeft() != null)
-                left = max(node.getLeft(), disc, discNext);
-            if (node.getRight() != null)
-                right = max(node.getRight(), disc, discNext);
-            if (left != null && right != null) {
-                double pointLeft = left.getKey().getPoint().getDouble(disc);
-                double pointRight = right.getKey().getPoint().getDouble(disc);
-                if (pointLeft > pointRight)
-                    return left;
-                else
-                    return right;
-            } else if (left != null)
-                return left;
-            else
-                return right;
-        }
-
-        return Pair.of(node, _disc);
-    }
-
-
-
-    private Pair<KDNode, Integer> min(KDNode node, int disc, int _disc) {
-        int discNext = (_disc + 1) % dims;
-        if (_disc == disc) {
-            KDNode child = node.getLeft();
-            if (child != null) {
-                return min(child, disc, discNext);
-            }
-        } else if (node.getLeft() != null || node.getRight() != null) {
-            Pair<KDNode, Integer> left = null, right = null;
-            if (node.getLeft() != null)
-                left = min(node.getLeft(), disc, discNext);
-            if (node.getRight() != null)
-                right = min(node.getRight(), disc, discNext);
-            if (left != null && right != null) {
-                double pointLeft = left.getKey().getPoint().getDouble(disc);
-                double pointRight = right.getKey().getPoint().getDouble(disc);
-                if (pointLeft < pointRight)
-                    return left;
-                else
-                    return right;
-            } else if (left != null)
-                return left;
-            else
-                return right;
-        }
-
-        return Pair.of(node, _disc);
-    }
-
-    /**
-     * The number of elements in the tree
-     * @return the number of elements in the tree
-     */
-    public int size() {
-        return size;
-    }
-
-    private int successor(KDNode node, INDArray point, int disc) {
-        for (int i = disc; i < dims; i++) {
-            double pointI = point.getDouble(i);
-            double nodePointI = node.getPoint().getDouble(i);
-            if (pointI < nodePointI)
-                return LESS;
-            else if (pointI > nodePointI)
-                return GREATER;
-
-        }
-
-        throw new IllegalStateException("Point is equal!");
-    }
-
-
-    private static class KDNode {
-        private INDArray point;
-        private KDNode left, right, parent;
-
-        public KDNode(INDArray point) {
-            this.point = point;
-        }
-
-        public INDArray getPoint() {
-            return point;
-        }
-
-        public KDNode getLeft() {
-            return left;
-        }
-
-        public void setLeft(KDNode left) {
-            this.left = left;
-        }
-
-        public KDNode getRight() {
-            return right;
-        }
-
-        public void setRight(KDNode right) {
-            this.right = right;
-        }
-
-        public KDNode getParent() {
-            return parent;
-        }
-
-        public void setParent(KDNode parent) {
-            this.parent = parent;
-        }
-    }
-
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kmeans/KMeansClustering.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kmeans/KMeansClustering.java
deleted file mode 100755
index 00b5bb3e9..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kmeans/KMeansClustering.java
+++ /dev/null
@@ -1,109 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the
accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.kmeans; - -import org.deeplearning4j.clustering.algorithm.BaseClusteringAlgorithm; -import org.deeplearning4j.clustering.algorithm.Distance; -import org.deeplearning4j.clustering.strategy.ClusteringStrategy; -import org.deeplearning4j.clustering.strategy.FixedClusterCountStrategy; - - -public class KMeansClustering extends BaseClusteringAlgorithm { - - private static final long serialVersionUID = 8476951388145944776L; - private static final double VARIATION_TOLERANCE= 1e-4; - - - /** - * - * @param clusteringStrategy - */ - protected KMeansClustering(ClusteringStrategy clusteringStrategy, boolean useKMeansPlusPlus) { - super(clusteringStrategy, useKMeansPlusPlus); - } - - /** - * Setup a kmeans instance - * @param clusterCount the number of clusters - * @param maxIterationCount the max number of iterations - * to run kmeans - * @param distanceFunction the distance function to use for grouping - * @return - */ - public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction, - boolean inverse, boolean useKMeansPlusPlus) { - ClusteringStrategy clusteringStrategy = - FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse); - clusteringStrategy.endWhenIterationCountEquals(maxIterationCount); - return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus); - } - - /** - * - * @param clusterCount - * @param minDistributionVariationRate - * @param distanceFunction - * @param allowEmptyClusters - * @return - */ - public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction, - boolean inverse, boolean allowEmptyClusters, boolean useKMeansPlusPlus) { - ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse) - .endWhenDistributionVariationRateLessThan(minDistributionVariationRate); - return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus); - } - - - /** - * Setup a kmeans instance - * @param clusterCount the number of clusters - * @param maxIterationCount the max number of iterations - * to run kmeans - * @param distanceFunction the distance function to use for grouping - * @return - */ - public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction, boolean useKMeansPlusPlus) { - return setup(clusterCount, maxIterationCount, distanceFunction, false, useKMeansPlusPlus); - } - - /** - * - * @param clusterCount - * @param minDistributionVariationRate - * @param distanceFunction - * @param allowEmptyClusters - * @return - */ - public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction, - boolean allowEmptyClusters, boolean 
useKMeansPlusPlus) {
-        ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false);
-        clusteringStrategy.endWhenDistributionVariationRateLessThan(minDistributionVariationRate);
-        return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
-    }
-
-    public static KMeansClustering setup(int clusterCount, Distance distanceFunction,
-                    boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
-        ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false);
-        clusteringStrategy.endWhenDistributionVariationRateLessThan(VARIATION_TOLERANCE);
-        return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
-    }
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/lsh/LSH.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/lsh/LSH.java
deleted file mode 100644
index b9fbffa7a..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/lsh/LSH.java
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.lsh;
-
-import org.nd4j.linalg.api.ndarray.INDArray;
-
-public interface LSH {
-
-    /**
-     * Returns an instance of the distance measure associated to the LSH family of this implementation.
-     * Beware, hashing families and their amplification constructs are distance-specific.
-     */
-    String getDistanceMeasure();
-
-    /**
-     * Returns the size of a hash compared against in one hashing bucket, corresponding to an AND construction
-     *
-     * denoting hashLength by h,
-     * amplifies a (d1, d2, p1, p2) hash family into a
-     * (d1, d2, p1^h, p2^h)-sensitive one (match probability is decreasing with h)
-     *
-     * @return the length of the hash in the AND construction used by this index
-     */
-    int getHashLength();
-
-    /**
-     *
-     * denoting numTables by n,
-     * amplifies a (d1, d2, p1, p2) hash family into a
-     * (d1, d2, (1 - (1 - p1)^n), (1 - (1 - p2)^n))-sensitive one (match probability is increasing with n)
-     *
-     * @return the # of hash tables in the OR construction used by this index
-     */
-    int getNumTables();
-
-    /**
-     * @return The dimension of the index vectors and queries
-     */
-    int getInDimension();
-
-    /**
-     * Populates the index with data vectors.
-     * @param data the vectors to index
-     */
-    void makeIndex(INDArray data);
-
-    /**
-     * Returns the set of all vectors that could approximately be considered neighbors of the query,
-     * without selection on the basis of distance or number of neighbors.
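A usage sketch for the deleted factory methods above. The applyTo(List<Point>) entry point is inherited from BaseClusteringAlgorithm, which is outside this hunk, so treat that call as an assumption:

    import java.util.List;
    import org.deeplearning4j.clustering.algorithm.Distance;
    import org.deeplearning4j.clustering.cluster.ClusterSet;
    import org.deeplearning4j.clustering.cluster.Point;
    import org.deeplearning4j.clustering.kmeans.KMeansClustering;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class KMeansExample {
        public static void main(String[] args) {
            INDArray data = Nd4j.rand(500, 8);
            List<Point> points = Point.toPoints(data);
            // 5 clusters, at most 20 iterations, Euclidean distance, k-means++ seeding.
            KMeansClustering kmeans = KMeansClustering.setup(5, 20, Distance.EUCLIDEAN, true);
            ClusterSet clusters = kmeans.applyTo(points); // assumed entry point from the parent class
            System.out.println(clusters.getClusters().size() + " clusters");
        }
    }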
- * @param query a vector to find neighbors for - * @return its approximate neighbors, unfiltered - */ - INDArray bucket(INDArray query); - - /** - * Returns the approximate neighbors within a distance bound. - * @param query a vector to find neighbors for - * @param maxRange the maximum distance between results and the query - * @return approximate neighbors within the distance bounds - */ - INDArray search(INDArray query, double maxRange); - - /** - * Returns the approximate neighbors within a k-closest bound - * @param query a vector to find neighbors for - * @param k the maximum number of closest neighbors to return - * @return at most k neighbors of the query, ordered by increasing distance - */ - INDArray search(INDArray query, int k); -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSH.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSH.java deleted file mode 100644 index 7b9873d73..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSH.java +++ /dev/null @@ -1,227 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
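The AND/OR amplification documented in the deleted interface composes as follows: all h bits of one table match with probability p^h, and at least one of n tables matches with probability 1 - (1 - p^h)^n. A standalone sketch of that arithmetic:

    // Probability that a pair with per-bit collision probability p survives
    // an AND construction of length hashLength, OR-ed across numTables tables.
    static double matchProbability(double p, int hashLength, int numTables) {
        double andProb = Math.pow(p, hashLength);        // all h bits agree in one table
        return 1.0 - Math.pow(1.0 - andProb, numTables); // at least one of n tables agrees
    }

For example, with h = 16 and n = 10, p = 0.9 gives roughly 0.87 while p = 0.5 collapses to about 0.00015, which is exactly the near/far separation the (d1, d2, p1, p2) notation above describes.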
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.lsh;
-
-import lombok.Getter;
-import lombok.val;
-import org.nd4j.common.base.Preconditions;
-import org.nd4j.linalg.api.buffer.DataType;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo;
-import org.nd4j.linalg.api.ops.impl.transforms.same.Sign;
-import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
-import org.nd4j.linalg.api.rng.Random;
-import org.nd4j.linalg.exception.ND4JIllegalStateException;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.indexing.BooleanIndexing;
-import org.nd4j.linalg.indexing.conditions.Conditions;
-import org.nd4j.linalg.ops.transforms.Transforms;
-
-import java.util.Arrays;
-
-
-public class RandomProjectionLSH implements LSH {
-
-    @Override
-    public String getDistanceMeasure(){
-        return "cosinedistance";
-    }
-
-    @Getter private int hashLength;
-
-    @Getter private int numTables;
-
-    @Getter private int inDimension;
-
-
-    @Getter private double radius;
-
-    INDArray randomProjection;
-
-    INDArray index;
-
-    INDArray indexData;
-
-
-    private INDArray gaussianRandomMatrix(int[] shape, Random rng){
-        INDArray res = Nd4j.create(shape);
-
-        GaussianDistribution op1 = new GaussianDistribution(res, 0.0, 1.0 / Math.sqrt(shape[0]));
-
-        Nd4j.getExecutioner().exec(op1, rng);
-        return res;
-    }
-
-    public RandomProjectionLSH(int hashLength, int numTables, int inDimension, double radius){
-        this(hashLength, numTables, inDimension, radius, Nd4j.getRandom());
-    }
-
-    /**
-     * Creates a locality-sensitive hashing index for the cosine distance,
-     * a (d1, d2, (180 − d1)/180, (180 − d2)/180)-sensitive hash family before amplification
-     *
-     * @param hashLength the length of the compared hash in an AND construction,
-     * @param numTables the entropy-equivalent of a number of hash tables in an OR construction, implemented here with the multiple
-     *                  probes of Panigrahy (op. cit.).
-     * @param inDimension the dimensionality of the points being indexed
-     * @param radius the radius of points to generate probes for.
Instead of using multiple physical hash tables in an OR construction
-     * @param rng a Random object to draw samples from
-     */
-    public RandomProjectionLSH(int hashLength, int numTables, int inDimension, double radius, Random rng){
-        this.hashLength = hashLength;
-        this.numTables = numTables;
-        this.inDimension = inDimension;
-        this.radius = radius;
-        randomProjection = gaussianRandomMatrix(new int[]{inDimension, hashLength}, rng);
-    }
-
-    /**
-     * This picks uniformly distributed random points on the surface of the unit sphere, using the method of:
-     *
-     * An efficient method for generating uniformly distributed points on the surface of an n-dimensional sphere
-     * JS Hicks, RF Wheeling - Communications of the ACM, 1959
-     * @param data a query to generate multiple probes for
-     * @return `numTables`
-     */
-    public INDArray entropy(INDArray data){
-
-        INDArray data2 =
-                        Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.create(numTables, inDimension), radius));
-
-        INDArray norms = Nd4j.norm2(data2.dup(), -1);
-
-        Preconditions.checkState(norms.rank() == 1 && norms.size(0) == numTables, "Expected norm2 to have shape [%s], is %ndShape", norms.size(0), norms);
-
-        data2.diviColumnVector(norms);
-        data2.addiRowVector(data);
-        return data2;
-    }
-
-    /**
-     * Returns hash values for a particular query
-     * @param data a query vector
-     * @return its hashed value
-     */
-    public INDArray hash(INDArray data) {
-        if (data.shape()[1] != inDimension){
-            throw new ND4JIllegalStateException(
-                            String.format("Invalid shape: Requested INDArray shape %s, this table expects dimension %d",
-                                            Arrays.toString(data.shape()), inDimension));
-        }
-        INDArray projected = data.mmul(randomProjection);
-        INDArray res = Nd4j.getExecutioner().exec(new Sign(projected));
-        return res;
-    }
-
-    /**
-     * Populates the index. Beware, not incremental, any further call replaces the index instead of adding to it.
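The deleted hash(INDArray) above is the classic sign-of-random-projection family for the cosine distance: two vectors receive the same bit from one random hyperplane with probability 1 - θ/π, where θ is the angle between them, which is what makes the sign pattern a locality-sensitive signature. A standalone sketch of that relation:

    // Probability that one random hyperplane assigns the same sign bit to a and b.
    static double bitAgreementProbability(double[] a, double[] b) {
        double dot = 0, normA = 0, normB = 0;
        for (int i = 0; i < a.length; i++) {
            dot += a[i] * b[i];
            normA += a[i] * a[i];
            normB += b[i] * b[i];
        }
        double theta = Math.acos(dot / (Math.sqrt(normA) * Math.sqrt(normB)));
        return 1.0 - theta / Math.PI;
    }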
- * @param data the vectors to index - */ - @Override - public void makeIndex(INDArray data) { - index = hash(data); - indexData = data; - } - - // data elements in the same bucket as the query, without entropy - INDArray rawBucketOf(INDArray query){ - INDArray pattern = hash(query); - - INDArray res = Nd4j.zeros(DataType.BOOL, index.shape()); - Nd4j.getExecutioner().exec(new BroadcastEqualTo(index, pattern, res, -1)); - return res.castTo(Nd4j.defaultFloatingPointType()).min(-1); - } - - @Override - public INDArray bucket(INDArray query) { - INDArray queryRes = rawBucketOf(query); - - if(numTables > 1) { - INDArray entropyQueries = entropy(query); - - // loop, addi + conditionalreplace -> poor man's OR function - for (int i = 0; i < numTables; i++) { - INDArray row = entropyQueries.getRow(i, true); - queryRes.addi(rawBucketOf(row)); - } - BooleanIndexing.replaceWhere(queryRes, 1.0, Conditions.greaterThan(0.0)); - } - - return queryRes; - } - - // data elements in the same entropy bucket as the query, - INDArray bucketData(INDArray query){ - INDArray mask = bucket(query); - int nRes = mask.sum(0).getInt(0); - INDArray res = Nd4j.create(new int[] {nRes, inDimension}); - int j = 0; - for (int i = 0; i < nRes; i++){ - while (mask.getInt(j) == 0 && j < mask.length() - 1) { - j += 1; - } - if (mask.getInt(j) == 1) res.putRow(i, indexData.getRow(j)); - j += 1; - } - return res; - } - - @Override - public INDArray search(INDArray query, double maxRange) { - if (maxRange < 0) - throw new IllegalArgumentException("ANN search should have a positive maximum search radius"); - - INDArray bucketData = bucketData(query); - INDArray distances = Transforms.allCosineDistances(bucketData, query, -1); - INDArray[] idxs = Nd4j.sortWithIndices(distances, -1, true); - - INDArray shuffleIndexes = idxs[0]; - INDArray sortedDistances = idxs[1]; - int accepted = 0; - while (accepted < sortedDistances.length() && sortedDistances.getInt(accepted) <= maxRange) accepted +=1; - - INDArray res = Nd4j.create(new int[] {accepted, inDimension}); - for(int i = 0; i < accepted; i++){ - res.putRow(i, bucketData.getRow(shuffleIndexes.getInt(i))); - } - return res; - } - - @Override - public INDArray search(INDArray query, int k) { - if (k < 1) - throw new IllegalArgumentException("An ANN search for k neighbors should at least seek one neighbor"); - - INDArray bucketData = bucketData(query); - INDArray distances = Transforms.allCosineDistances(bucketData, query, -1); - INDArray[] idxs = Nd4j.sortWithIndices(distances, -1, true); - - INDArray shuffleIndexes = idxs[0]; - INDArray sortedDistances = idxs[1]; - val accepted = Math.min(k, sortedDistances.shape()[1]); - - INDArray res = Nd4j.create(accepted, inDimension); - for(int i = 0; i < accepted; i++){ - res.putRow(i, bucketData.getRow(shuffleIndexes.getInt(i))); - } - return res; - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/optimisation/ClusteringOptimization.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/optimisation/ClusteringOptimization.java deleted file mode 100644 index b65571de3..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/optimisation/ClusteringOptimization.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and 
the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.optimisation; - -import lombok.AccessLevel; -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.io.Serializable; - -@Data -@NoArgsConstructor(access = AccessLevel.PROTECTED) -@AllArgsConstructor -public class ClusteringOptimization implements Serializable { - - private ClusteringOptimizationType type; - private double value; - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/optimisation/ClusteringOptimizationType.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/optimisation/ClusteringOptimizationType.java deleted file mode 100644 index a2220010e..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/optimisation/ClusteringOptimizationType.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
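Note on the approximate-nearest-neighbor search deleted above: makeIndex stores one hash per indexed vector, bucket() widens the lookup by OR-ing the query's bucket mask with the masks of numTables entropy-perturbed copies of the query, and the two search variants rank the surviving candidates by cosine distance (keeping everything within maxRange, or the top k). A minimal sketch of that flow in plain Java, assuming double[] vectors and a hypothetical Hasher stand-in for the deleted hash()/entropy() methods:

    import java.util.*;

    final class EntropyLshSketch {
        // Hypothetical stand-in for the deleted hash(...) method.
        interface Hasher { long hash(double[] v); }

        // Keep every indexed vector that lands in the query's bucket or in the
        // bucket of an entropy-perturbed copy, then rank by cosine distance.
        static List<double[]> search(double[] query, List<double[]> index, Hasher hasher,
                                     List<double[]> perturbedQueries, int k) {
            Set<Long> buckets = new HashSet<>();
            buckets.add(hasher.hash(query));
            for (double[] p : perturbedQueries)
                buckets.add(hasher.hash(p));

            List<double[]> candidates = new ArrayList<>();
            for (double[] v : index)
                if (buckets.contains(hasher.hash(v)))
                    candidates.add(v);

            candidates.sort(Comparator.comparingDouble(v -> cosineDistance(v, query)));
            return candidates.subList(0, Math.min(k, candidates.size()));
        }

        static double cosineDistance(double[] a, double[] b) {
            double dot = 0, na = 0, nb = 0;
            for (int i = 0; i < a.length; i++) {
                dot += a[i] * b[i];
                na += a[i] * a[i];
                nb += b[i] * b[i];
            }
            return 1.0 - dot / (Math.sqrt(na) * Math.sqrt(nb) + 1e-12);
        }
    }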
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.optimisation; - -/** - * - */ -public enum ClusteringOptimizationType { - MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE, MINIMIZE_MAXIMUM_POINT_TO_CENTER_DISTANCE, MINIMIZE_AVERAGE_POINT_TO_POINT_DISTANCE, MINIMIZE_MAXIMUM_POINT_TO_POINT_DISTANCE, MINIMIZE_PER_CLUSTER_POINT_COUNT -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/Cell.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/Cell.java deleted file mode 100644 index cb82b6f87..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/Cell.java +++ /dev/null @@ -1,115 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.quadtree; - -import org.nd4j.linalg.api.ndarray.INDArray; - -import java.io.Serializable; - -public class Cell implements Serializable { - private double x, y, hw, hh; - - public Cell(double x, double y, double hw, double hh) { - this.x = x; - this.y = y; - this.hw = hw; - this.hh = hh; - } - - /** - * Whether the given point is contained - * within this cell - * @param point the point to check - * @return true if the point is contained, false otherwise - */ - public boolean containsPoint(INDArray point) { - double first = point.getDouble(0), second = point.getDouble(1); - return x - hw <= first && x + hw >= first && y - hh <= second && y + hh >= second; - } - - @Override - public boolean equals(Object o) { - if (this == o) - return true; - if (!(o instanceof Cell)) - return false; - - Cell cell = (Cell) o; - - if (Double.compare(cell.hh, hh) != 0) - return false; - if (Double.compare(cell.hw, hw) != 0) - return false; - if (Double.compare(cell.x, x) != 0) - return false; - return Double.compare(cell.y, y) == 0; - - } - - @Override - public int hashCode() { - int result; - long temp; - temp = Double.doubleToLongBits(x); - result = (int) (temp ^ (temp >>> 32)); - temp = Double.doubleToLongBits(y); - result = 31 * result + (int) (temp ^ (temp >>> 32)); - temp = Double.doubleToLongBits(hw); - result = 31 * result + (int) (temp ^ (temp >>> 32)); - temp = Double.doubleToLongBits(hh); - result = 31 * result + (int) (temp ^ (temp >>> 32)); - return result; - } - - public double getX() { - return x; - } - - public void setX(double x) { - this.x = x; - } - - public double getY() { - return y; - } - - public void setY(double y) 
{ - this.y = y; - } - - public double getHw() { - return hw; - } - - public void setHw(double hw) { - this.hw = hw; - } - - public double getHh() { - return hh; - } - - public void setHh(double hh) { - this.hh = hh; - } - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java deleted file mode 100644 index 20d216b44..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java +++ /dev/null @@ -1,383 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.quadtree; - -import org.nd4j.shade.guava.util.concurrent.AtomicDouble; -import org.apache.commons.math3.util.FastMath; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.io.Serializable; - -import static java.lang.Math.max; - -public class QuadTree implements Serializable { - private QuadTree parent, northWest, northEast, southWest, southEast; - private boolean isLeaf = true; - private int size, cumSize; - private Cell boundary; - static final int QT_NO_DIMS = 2; - static final int QT_NODE_CAPACITY = 1; - private INDArray buf = Nd4j.create(QT_NO_DIMS); - private INDArray data, centerOfMass = Nd4j.create(QT_NO_DIMS); - private int[] index = new int[QT_NODE_CAPACITY]; - - - /** - * Pass in a matrix - * @param data - */ - public QuadTree(INDArray data) { - INDArray meanY = data.mean(0); - INDArray minY = data.min(0); - INDArray maxY = data.max(0); - init(data, meanY.getDouble(0), meanY.getDouble(1), - max(maxY.getDouble(0) - meanY.getDouble(0), meanY.getDouble(0) - minY.getDouble(0)) - + Nd4j.EPS_THRESHOLD, - max(maxY.getDouble(1) - meanY.getDouble(1), meanY.getDouble(1) - minY.getDouble(1)) - + Nd4j.EPS_THRESHOLD); - fill(); - } - - public QuadTree(QuadTree parent, INDArray data, Cell boundary) { - this.parent = parent; - this.boundary = boundary; - this.data = data; - - } - - public QuadTree(Cell boundary) { - this.boundary = boundary; - } - - private void init(INDArray data, double x, double y, double hw, double hh) { - boundary = new Cell(x, y, hw, hh); - this.data = data; - } - - private void fill() { - for (int i = 0; i < data.rows(); i++) - insert(i); - } - - - - /** - * Returns the cell of this element - * - * @param coordinates - * @return - */ - protected QuadTree findIndex(INDArray coordinates) { - - // Compute the sector for the coordinates - boolean left = 
(coordinates.getDouble(0) <= (boundary.getX() + boundary.getHw() / 2)); - boolean top = (coordinates.getDouble(1) <= (boundary.getY() + boundary.getHh() / 2)); - - // top left - QuadTree index = getNorthWest(); - if (left) { - // left side - if (!top) { - // bottom left - index = getSouthWest(); - } - } else { - // right side - if (top) { - // top right - index = getNorthEast(); - } else { - // bottom right - index = getSouthEast(); - - } - } - - return index; - } - - - /** - * Insert an index of the data in to the tree - * @param newIndex the index to insert in to the tree - * @return whether the index was inserted or not - */ - public boolean insert(int newIndex) { - // Ignore objects which do not belong in this quad tree - INDArray point = data.slice(newIndex); - if (!boundary.containsPoint(point)) - return false; - - cumSize++; - double mult1 = (double) (cumSize - 1) / (double) cumSize; - double mult2 = 1.0 / (double) cumSize; - - centerOfMass.muli(mult1); - centerOfMass.addi(point.mul(mult2)); - - // If there is space in this quad tree and it is a leaf, add the object here - if (isLeaf() && size < QT_NODE_CAPACITY) { - index[size] = newIndex; - size++; - return true; - } - - //duplicate point - if (size > 0) { - for (int i = 0; i < size; i++) { - INDArray compPoint = data.slice(index[i]); - if (point.getDouble(0) == compPoint.getDouble(0) && point.getDouble(1) == compPoint.getDouble(1)) - return true; - } - } - - - - // If this Node has already been subdivided just add the elements to the - // appropriate cell - if (!isLeaf()) { - QuadTree index = findIndex(point); - index.insert(newIndex); - return true; - } - - if (isLeaf()) - subDivide(); - - boolean ret = insertIntoOneOf(newIndex); - - - - return ret; - } - - private boolean insertIntoOneOf(int index) { - boolean success = false; - success = northWest.insert(index); - if (!success) - success = northEast.insert(index); - if (!success) - success = southWest.insert(index); - if (!success) - success = southEast.insert(index); - return success; - } - - - /** - * Returns whether the tree is consistent or not - * @return whether the tree is consistent or not - */ - public boolean isCorrect() { - - for (int n = 0; n < size; n++) { - INDArray point = data.slice(index[n]); - if (!boundary.containsPoint(point)) - return false; - } - - return isLeaf() || northWest.isCorrect() && northEast.isCorrect() && southWest.isCorrect() - && southEast.isCorrect(); - - } - - - - /** - * Create four children - * which fully divide this cell - * into four quads of equal area - */ - public void subDivide() { - northWest = new QuadTree(this, data, new Cell(boundary.getX() - .5 * boundary.getHw(), - boundary.getY() - .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh())); - northEast = new QuadTree(this, data, new Cell(boundary.getX() + .5 * boundary.getHw(), - boundary.getY() - .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh())); - southWest = new QuadTree(this, data, new Cell(boundary.getX() - .5 * boundary.getHw(), - boundary.getY() + .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh())); - southEast = new QuadTree(this, data, new Cell(boundary.getX() + .5 * boundary.getHw(), - boundary.getY() + .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh())); - - } - - - /** - * Compute non edge forces using barnes hut - * @param pointIndex - * @param theta - * @param negativeForce - * @param sumQ - */ - public void computeNonEdgeForces(int pointIndex, double theta, INDArray negativeForce, AtomicDouble 
sumQ) { - // Make sure that we spend no time on empty nodes or self-interactions - if (cumSize == 0 || (isLeaf() && size == 1 && index[0] == pointIndex)) - return; - - - // Compute distance between point and center-of-mass - buf.assign(data.slice(pointIndex)).subi(centerOfMass); - - double D = Nd4j.getBlasWrapper().dot(buf, buf); - - // Check whether we can use this node as a "summary" - if (isLeaf || FastMath.max(boundary.getHh(), boundary.getHw()) / FastMath.sqrt(D) < theta) { - - // Compute and add t-SNE force between point and current node - double Q = 1.0 / (1.0 + D); - double mult = cumSize * Q; - sumQ.addAndGet(mult); - mult *= Q; - negativeForce.addi(buf.mul(mult)); - - } else { - - // Recursively apply Barnes-Hut to children - northWest.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ); - northEast.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ); - southWest.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ); - southEast.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ); - } - } - - - - /** - * - * @param rowP a vector - * @param colP - * @param valP - * @param N - * @param posF - */ - public void computeEdgeForces(INDArray rowP, INDArray colP, INDArray valP, int N, INDArray posF) { - if (!rowP.isVector()) - throw new IllegalArgumentException("RowP must be a vector"); - - // Loop over all edges in the graph - double D; - for (int n = 0; n < N; n++) { - for (int i = rowP.getInt(n); i < rowP.getInt(n + 1); i++) { - - // Compute pairwise distance and Q-value - buf.assign(data.slice(n)).subi(data.slice(colP.getInt(i))); - - D = Nd4j.getBlasWrapper().dot(buf, buf); - D = valP.getDouble(i) / D; - - // Sum positive force - posF.slice(n).addi(buf.mul(D)); - - } - } - } - - - /** - * The depth of the node - * @return the depth of the node - */ - public int depth() { - if (isLeaf()) - return 1; - return 1 + max(max(northWest.depth(), northEast.depth()), max(southWest.depth(), southEast.depth())); - } - - public INDArray getCenterOfMass() { - return centerOfMass; - } - - public void setCenterOfMass(INDArray centerOfMass) { - this.centerOfMass = centerOfMass; - } - - public QuadTree getParent() { - return parent; - } - - public void setParent(QuadTree parent) { - this.parent = parent; - } - - public QuadTree getNorthWest() { - return northWest; - } - - public void setNorthWest(QuadTree northWest) { - this.northWest = northWest; - } - - public QuadTree getNorthEast() { - return northEast; - } - - public void setNorthEast(QuadTree northEast) { - this.northEast = northEast; - } - - public QuadTree getSouthWest() { - return southWest; - } - - public void setSouthWest(QuadTree southWest) { - this.southWest = southWest; - } - - public QuadTree getSouthEast() { - return southEast; - } - - public void setSouthEast(QuadTree southEast) { - this.southEast = southEast; - } - - public boolean isLeaf() { - return isLeaf; - } - - public void setLeaf(boolean isLeaf) { - this.isLeaf = isLeaf; - } - - public int getSize() { - return size; - } - - public void setSize(int size) { - this.size = size; - } - - public int getCumSize() { - return cumSize; - } - - public void setCumSize(int cumSize) { - this.cumSize = cumSize; - } - - public Cell getBoundary() { - return boundary; - } - - public void setBoundary(Cell boundary) { - this.boundary = boundary; - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPForest.java 
b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPForest.java deleted file mode 100644 index f814025d5..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPForest.java +++ /dev/null @@ -1,104 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.randomprojection; - -import lombok.Data; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.common.primitives.Pair; - -import java.util.ArrayList; -import java.util.List; - -/** - * - */ -@Data -public class RPForest { - - private int numTrees; - private List trees; - private INDArray data; - private int maxSize = 1000; - private String similarityFunction; - - /** - * Create the rp forest with the specified number of trees - * @param numTrees the number of trees in the forest - * @param maxSize the max size of each tree - * @param similarityFunction the distance function to use - */ - public RPForest(int numTrees,int maxSize,String similarityFunction) { - this.numTrees = numTrees; - this.maxSize = maxSize; - this.similarityFunction = similarityFunction; - trees = new ArrayList<>(numTrees); - - } - - - /** - * Build the trees from the given dataset - * @param x the input dataset (should be a 2d matrix) - */ - public void fit(INDArray x) { - this.data = x; - for(int i = 0; i < numTrees; i++) { - RPTree tree = new RPTree(data.columns(),maxSize,similarityFunction); - tree.buildTree(x); - trees.add(tree); - } - } - - /** - * Get all candidates relative to a specific datapoint. 
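For context, a hypothetical usage sketch of the RPForest API being deleted here (fit and queryAll as declared in this file; the shapes and parameter values are assumptions, not taken from the codebase):

    import org.deeplearning4j.clustering.randomprojection.RPForest;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class RPForestExample {
        public static void main(String[] args) {
            INDArray data = Nd4j.randn(new int[] {1000, 32});     // 1000 vectors of dimension 32
            RPForest forest = new RPForest(10, 100, "euclidean"); // 10 trees, leaves of at most 100 points
            forest.fit(data);                                     // builds every tree over the matrix

            INDArray query = Nd4j.randn(new int[] {1, 32});
            INDArray neighbours = forest.queryAll(query, 5);      // indices of ~5 approximate nearest vectors
            System.out.println(neighbours);
        }
    }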
- * @param input - * @return - */ - public INDArray getAllCandidates(INDArray input) { - return RPUtils.getAllCandidates(input,trees,similarityFunction); - } - - /** - * Query results up to length n - * nearest neighbors - * @param toQuery the query item - * @param n the number of nearest neighbors for the given data point - * @return the indices for the nearest neighbors - */ - public INDArray queryAll(INDArray toQuery,int n) { - return RPUtils.queryAll(toQuery,data,trees,n,similarityFunction); - } - - - /** - * Query all with the distances - * sorted by index - * @param query the query vector - * @param numResults the number of results to return - * @return a list of samples - */ - public List> queryWithDistances(INDArray query, int numResults) { - return RPUtils.queryAllWithDistances(query,this.data, trees,numResults,similarityFunction); - } - - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPHyperPlanes.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPHyperPlanes.java deleted file mode 100644 index 979013797..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPHyperPlanes.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.randomprojection; - -import lombok.Data; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -@Data -public class RPHyperPlanes { - private int dim; - private INDArray wholeHyperPlane; - - public RPHyperPlanes(int dim) { - this.dim = dim; - } - - public INDArray getHyperPlaneAt(int depth) { - if(wholeHyperPlane.isVector()) - return wholeHyperPlane; - return wholeHyperPlane.slice(depth); - } - - - /** - * Add a new random element to the hyper plane. 
- */ - public void addRandomHyperPlane() { - INDArray newPlane = Nd4j.randn(new int[] {1,dim}); - newPlane.divi(newPlane.normmaxNumber()); - if(wholeHyperPlane == null) - wholeHyperPlane = newPlane; - else { - wholeHyperPlane = Nd4j.concat(0,wholeHyperPlane,newPlane); - } - } - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPNode.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPNode.java deleted file mode 100644 index 9a103469e..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPNode.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.randomprojection; - - -import lombok.Data; - -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.Future; - -@Data -public class RPNode { - private int depth; - private RPNode left,right; - private Future leftFuture,rightFuture; - private List indices; - private double median; - private RPTree tree; - - - public RPNode(RPTree tree,int depth) { - this.depth = depth; - this.tree = tree; - indices = new ArrayList<>(); - } - - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPTree.java deleted file mode 100644 index 7fbca2b90..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPTree.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
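The hyperplanes above are random Gaussian vectors (normalized by their max-norm), one per tree depth. During construction (RPUtils.buildTree, later in this diff) each node scores its points against the plane for its depth using the configured distance function and splits them at the median score. A toy sketch of one such split, assuming plain double[] points and the default Euclidean scoring:

    import java.util.*;

    final class MedianSplitSketch {
        // Score every point against the node's plane vector, then send points
        // at or below the median score left and the rest right, mirroring the
        // deleted buildTree logic.
        static void split(double[][] points, double[] plane,
                          List<Integer> left, List<Integer> right) {
            double[] scores = new double[points.length];
            for (int i = 0; i < points.length; i++)
                scores[i] = euclidean(points[i], plane);

            double[] sorted = scores.clone();
            Arrays.sort(sorted);
            double median = sorted[sorted.length / 2];

            for (int i = 0; i < points.length; i++)
                (scores[i] <= median ? left : right).add(i);
        }

        static double euclidean(double[] a, double[] b) {
            double s = 0;
            for (int i = 0; i < a.length; i++)
                s += (a[i] - b[i]) * (a[i] - b[i]);
            return Math.sqrt(s);
        }
    }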
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.randomprojection; - -import lombok.Builder; -import lombok.Data; -import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; -import org.nd4j.linalg.api.memory.enums.*; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.common.primitives.Pair; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.concurrent.ExecutorService; - -@Data -public class RPTree { - private RPNode root; - private RPHyperPlanes rpHyperPlanes; - private int dim; - //also knows as leave size - private int maxSize; - private INDArray X; - private String similarityFunction = "euclidean"; - private WorkspaceConfiguration workspaceConfiguration; - private ExecutorService searchExecutor; - private int searchWorkers; - - /** - * - * @param dim the dimension of the vectors - * @param maxSize the max size of the leaves - * - */ - @Builder - public RPTree(int dim, int maxSize,String similarityFunction) { - this.dim = dim; - this.maxSize = maxSize; - rpHyperPlanes = new RPHyperPlanes(dim); - root = new RPNode(this,0); - this.similarityFunction = similarityFunction; - workspaceConfiguration = WorkspaceConfiguration.builder().cyclesBeforeInitialization(1) - .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.FIRST_LOOP) - .policyMirroring(MirroringPolicy.FULL).policyReset(ResetPolicy.BLOCK_LEFT) - .policySpill(SpillPolicy.REALLOCATE).build(); - - } - - /** - * - * @param dim the dimension of the vectors - * @param maxSize the max size of the leaves - * - */ - public RPTree(int dim, int maxSize) { - this(dim,maxSize,"euclidean"); - } - - /** - * Build the tree with the given input data - * @param x - */ - - public void buildTree(INDArray x) { - this.X = x; - for(int i = 0; i < x.rows(); i++) { - root.getIndices().add(i); - } - - - - RPUtils.buildTree(this,root,rpHyperPlanes, - x,maxSize,0,similarityFunction); - } - - - - public void addNodeAtIndex(int idx,INDArray toAdd) { - RPNode query = RPUtils.query(root,rpHyperPlanes,toAdd,similarityFunction); - query.getIndices().add(idx); - } - - - public List getLeaves() { - List nodes = new ArrayList<>(); - RPUtils.scanForLeaves(nodes,getRoot()); - return nodes; - } - - - /** - * Query all with the distances - * sorted by index - * @param query the query vector - * @param numResults the number of results to return - * @return a list of samples - */ - public List> queryWithDistances(INDArray query, int numResults) { - return RPUtils.queryAllWithDistances(query,X,Arrays.asList(this),numResults,similarityFunction); - } - - public INDArray query(INDArray query,int numResults) { - return RPUtils.queryAll(query,X,Arrays.asList(this),numResults,similarityFunction); - } - - public List getCandidates(INDArray target) { - return RPUtils.getCandidates(target,Arrays.asList(this),similarityFunction); - } - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java deleted file mode 100644 index 0bd2574e7..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java +++ /dev/null @@ -1,481 
+0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.randomprojection; - -import org.nd4j.shade.guava.primitives.Doubles; -import lombok.val; -import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.ReduceOp; -import org.nd4j.linalg.api.ops.impl.reduce3.*; -import org.nd4j.linalg.exception.ND4JIllegalArgumentException; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.Pair; - -import java.util.*; - -public class RPUtils { - - - private static ThreadLocal> functionInstances = new ThreadLocal<>(); - - public static DifferentialFunction getOp(String name, - INDArray x, - INDArray y, - INDArray result) { - Map ops = functionInstances.get(); - if(ops == null) { - ops = new HashMap<>(); - functionInstances.set(ops); - } - - boolean allDistances = x.length() != y.length(); - - switch(name) { - case "cosinedistance": - if(!ops.containsKey(name) || ((CosineDistance)ops.get(name)).isComplexAccumulation() != allDistances) { - CosineDistance cosineDistance = new CosineDistance(x,y,result,allDistances); - ops.put(name,cosineDistance); - return cosineDistance; - } - else { - CosineDistance cosineDistance = (CosineDistance) ops.get(name); - return cosineDistance; - } - case "cosinesimilarity": - if(!ops.containsKey(name) || ((CosineSimilarity)ops.get(name)).isComplexAccumulation() != allDistances) { - CosineSimilarity cosineSimilarity = new CosineSimilarity(x,y,result,allDistances); - ops.put(name,cosineSimilarity); - return cosineSimilarity; - } - else { - CosineSimilarity cosineSimilarity = (CosineSimilarity) ops.get(name); - cosineSimilarity.setX(x); - cosineSimilarity.setY(y); - cosineSimilarity.setZ(result); - return cosineSimilarity; - - } - case "manhattan": - if(!ops.containsKey(name) || ((ManhattanDistance)ops.get(name)).isComplexAccumulation() != allDistances) { - ManhattanDistance manhattanDistance = new ManhattanDistance(x,y,result,allDistances); - ops.put(name,manhattanDistance); - return manhattanDistance; - } - else { - ManhattanDistance manhattanDistance = (ManhattanDistance) ops.get(name); - manhattanDistance.setX(x); - manhattanDistance.setY(y); - manhattanDistance.setZ(result); - return manhattanDistance; - } - case "jaccard": - if(!ops.containsKey(name) || ((JaccardDistance)ops.get(name)).isComplexAccumulation() != allDistances) { - JaccardDistance jaccardDistance = new JaccardDistance(x,y,result,allDistances); - ops.put(name,jaccardDistance); - return jaccardDistance; - } - else { - JaccardDistance jaccardDistance = (JaccardDistance) ops.get(name); - jaccardDistance.setX(x); - jaccardDistance.setY(y); - 
jaccardDistance.setZ(result); - return jaccardDistance; - } - case "hamming": - if(!ops.containsKey(name) || ((HammingDistance)ops.get(name)).isComplexAccumulation() != allDistances) { - HammingDistance hammingDistance = new HammingDistance(x,y,result,allDistances); - ops.put(name,hammingDistance); - return hammingDistance; - } - else { - HammingDistance hammingDistance = (HammingDistance) ops.get(name); - hammingDistance.setX(x); - hammingDistance.setY(y); - hammingDistance.setZ(result); - return hammingDistance; - } - //euclidean - default: - if(!ops.containsKey(name) || ((EuclideanDistance)ops.get(name)).isComplexAccumulation() != allDistances) { - EuclideanDistance euclideanDistance = new EuclideanDistance(x,y,result,allDistances); - ops.put(name,euclideanDistance); - return euclideanDistance; - } - else { - EuclideanDistance euclideanDistance = (EuclideanDistance) ops.get(name); - euclideanDistance.setX(x); - euclideanDistance.setY(y); - euclideanDistance.setZ(result); - return euclideanDistance; - } - } - } - - - /** - * Query all trees using the given input and data - * @param toQuery the query vector - * @param X the input data to query - * @param trees the trees to query - * @param n the number of results to search for - * @param similarityFunction the similarity function to use - * @return the indices (in order) in the ndarray - */ - public static List> queryAllWithDistances(INDArray toQuery,INDArray X,List trees,int n,String similarityFunction) { - if(trees.isEmpty()) { - throw new ND4JIllegalArgumentException("Trees is empty!"); - } - - List candidates = getCandidates(toQuery, trees,similarityFunction); - val sortedCandidates = sortCandidates(toQuery,X,candidates,similarityFunction); - int numReturns = Math.min(n,sortedCandidates.size()); - List> ret = new ArrayList<>(numReturns); - for(int i = 0; i < numReturns; i++) { - ret.add(sortedCandidates.get(i)); - } - - return ret; - } - - /** - * Query all trees using the given input and data - * @param toQuery the query vector - * @param X the input data to query - * @param trees the trees to query - * @param n the number of results to search for - * @param similarityFunction the similarity function to use - * @return the indices (in order) in the ndarray - */ - public static INDArray queryAll(INDArray toQuery,INDArray X,List trees,int n,String similarityFunction) { - if(trees.isEmpty()) { - throw new ND4JIllegalArgumentException("Trees is empty!"); - } - - List candidates = getCandidates(toQuery, trees,similarityFunction); - val sortedCandidates = sortCandidates(toQuery,X,candidates,similarityFunction); - int numReturns = Math.min(n,sortedCandidates.size()); - - INDArray result = Nd4j.create(numReturns); - for(int i = 0; i < numReturns; i++) { - result.putScalar(i,sortedCandidates.get(i).getSecond()); - } - - - return result; - } - - /** - * Get the sorted distances given the - * query vector, input data, given the list of possible search candidates - * @param x the query vector - * @param X the input data to use - * @param candidates the possible search candidates - * @param similarityFunction the similarity function to use - * @return the sorted distances - */ - public static List> sortCandidates(INDArray x,INDArray X, - List candidates, - String similarityFunction) { - int prevIdx = -1; - List> ret = new ArrayList<>(); - for(int i = 0; i < candidates.size(); i++) { - if(candidates.get(i) != prevIdx) { - ret.add(Pair.of(computeDistance(similarityFunction,X.slice(candidates.get(i)),x),candidates.get(i))); - } - - prevIdx = i; - } 
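    // NOTE: the loop above intends to skip consecutive duplicate candidate indices,
    // but it records the loop counter (prevIdx = i) rather than the candidate value,
    // so candidates.get(i) != prevIdx almost never detects a duplicate;
    // prevIdx = candidates.get(i) would implement the intended de-duplication.
    // The same pattern appears in getAllCandidates below.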
- - - Collections.sort(ret, new Comparator>() { - @Override - public int compare(Pair doubleIntegerPair, Pair t1) { - return Doubles.compare(doubleIntegerPair.getFirst(),t1.getFirst()); - } - }); - - return ret; - } - - - - /** - * Get the search candidates as indices given the input - * and similarity function - * @param x the input data to search with - * @param trees the trees to search - * @param similarityFunction the function to use for similarity - * @return the list of indices as the search results - */ - public static INDArray getAllCandidates(INDArray x,List trees,String similarityFunction) { - List candidates = getCandidates(x,trees,similarityFunction); - Collections.sort(candidates); - - int prevIdx = -1; - int idxCount = 0; - List> scores = new ArrayList<>(); - for(int i = 0; i < candidates.size(); i++) { - if(candidates.get(i) == prevIdx) { - idxCount++; - } - else if(prevIdx != -1) { - scores.add(Pair.of(idxCount,prevIdx)); - idxCount = 1; - } - - prevIdx = i; - } - - - scores.add(Pair.of(idxCount,prevIdx)); - - INDArray arr = Nd4j.create(scores.size()); - for(int i = 0; i < scores.size(); i++) { - arr.putScalar(i,scores.get(i).getSecond()); - } - - return arr; - } - - - /** - * Get the search candidates as indices given the input - * and similarity function - * @param x the input data to search with - * @param roots the trees to search - * @param similarityFunction the function to use for similarity - * @return the list of indices as the search results - */ - public static List getCandidates(INDArray x,List roots,String similarityFunction) { - Set ret = new LinkedHashSet<>(); - for(RPTree tree : roots) { - RPNode root = tree.getRoot(); - RPNode query = query(root,tree.getRpHyperPlanes(),x,similarityFunction); - ret.addAll(query.getIndices()); - } - - return new ArrayList<>(ret); - } - - - /** - * Query the tree starting from the given node - * using the given hyper plane and similarity function - * @param from the node to start from - * @param planes the hyper plane to query - * @param x the input data - * @param similarityFunction the similarity function to use - * @return the leaf node representing the given query from a - * search in the tree - */ - public static RPNode query(RPNode from,RPHyperPlanes planes,INDArray x,String similarityFunction) { - if(from.getLeft() == null && from.getRight() == null) { - return from; - } - - INDArray hyperPlane = planes.getHyperPlaneAt(from.getDepth()); - double dist = computeDistance(similarityFunction,x,hyperPlane); - if(dist <= from.getMedian()) { - return query(from.getLeft(),planes,x,similarityFunction); - } - - else { - return query(from.getRight(),planes,x,similarityFunction); - } - - } - - - /** - * Compute the distance between 2 vectors - * given a function name. Valid function names: - * euclidean: euclidean distance - * cosinedistance: cosine distance - * cosine similarity: cosine similarity - * manhattan: manhattan distance - * jaccard: jaccard distance - * hamming: hamming distance - * @param function the function to use (default euclidean distance) - * @param x the first vector - * @param y the second vector - * @return the distance between the 2 vectors given the inputs - */ - public static INDArray computeDistanceMulti(String function,INDArray x,INDArray y,INDArray result) { - ReduceOp op = (ReduceOp) getOp(function, x, y, result); - op.setDimensions(1); - Nd4j.getExecutioner().exec(op); - return op.z(); - } - - /** - - /** - * Compute the distance between 2 vectors - * given a function name. 
Valid function names: - * euclidean: euclidean distance - * cosinedistance: cosine distance - * cosine similarity: cosine similarity - * manhattan: manhattan distance - * jaccard: jaccard distance - * hamming: hamming distance - * @param function the function to use (default euclidean distance) - * @param x the first vector - * @param y the second vector - * @return the distance between the 2 vectors given the inputs - */ - public static double computeDistance(String function,INDArray x,INDArray y,INDArray result) { - ReduceOp op = (ReduceOp) getOp(function, x, y, result); - Nd4j.getExecutioner().exec(op); - return op.z().getDouble(0); - } - - /** - * Compute the distance between 2 vectors - * given a function name. Valid function names: - * euclidean: euclidean distance - * cosinedistance: cosine distance - * cosine similarity: cosine similarity - * manhattan: manhattan distance - * jaccard: jaccard distance - * hamming: hamming distance - * @param function the function to use (default euclidean distance) - * @param x the first vector - * @param y the second vector - * @return the distance between the 2 vectors given the inputs - */ - public static double computeDistance(String function,INDArray x,INDArray y) { - return computeDistance(function,x,y,Nd4j.scalar(0.0)); - } - - /** - * Initialize the tree given the input parameters - * @param tree the tree to initialize - * @param from the starting node - * @param planes the hyper planes to use (vector space for similarity) - * @param X the input data - * @param maxSize the max number of indices on a given leaf node - * @param depth the current depth of the tree - * @param similarityFunction the similarity function to use - */ - public static void buildTree(RPTree tree, - RPNode from, - RPHyperPlanes planes, - INDArray X, - int maxSize, - int depth, - String similarityFunction) { - if(from.getIndices().size() <= maxSize) { - //slimNode - slimNode(from); - return; - } - - - List distances = new ArrayList<>(); - RPNode left = new RPNode(tree,depth + 1); - RPNode right = new RPNode(tree,depth + 1); - - if(planes.getWholeHyperPlane() == null || depth >= planes.getWholeHyperPlane().rows()) { - planes.addRandomHyperPlane(); - } - - - INDArray hyperPlane = planes.getHyperPlaneAt(depth); - - - - for(int i = 0; i < from.getIndices().size(); i++) { - double cosineSim = computeDistance(similarityFunction,hyperPlane,X.slice(from.getIndices().get(i))); - distances.add(cosineSim); - } - - Collections.sort(distances); - from.setMedian(distances.get(distances.size() / 2)); - - - for(int i = 0; i < from.getIndices().size(); i++) { - double cosineSim = computeDistance(similarityFunction,hyperPlane,X.slice(from.getIndices().get(i))); - if(cosineSim <= from.getMedian()) { - left.getIndices().add(from.getIndices().get(i)); - } - else { - right.getIndices().add(from.getIndices().get(i)); - } - } - - //failed split - if(left.getIndices().isEmpty() || right.getIndices().isEmpty()) { - slimNode(from); - return; - } - - - from.setLeft(left); - from.setRight(right); - slimNode(from); - - - buildTree(tree,left,planes,X,maxSize,depth + 1,similarityFunction); - buildTree(tree,right,planes,X,maxSize,depth + 1,similarityFunction); - - } - - - /** - * Scan for leaves accumulating - * the nodes in the passed in list - * @param nodes the nodes so far - * @param scan the tree to scan - */ - public static void scanForLeaves(List nodes,RPTree scan) { - scanForLeaves(nodes,scan.getRoot()); - } - - /** - * Scan for leaves accumulating - * the nodes in the passed in list - * @param 
nodes the nodes so far - */ - public static void scanForLeaves(List nodes,RPNode current) { - if(current.getLeft() == null && current.getRight() == null) - nodes.add(current); - if(current.getLeft() != null) - scanForLeaves(nodes,current.getLeft()); - if(current.getRight() != null) - scanForLeaves(nodes,current.getRight()); - } - - - /** - * Prune indices from the given node - * when it's a leaf - * @param node the node to prune - */ - public static void slimNode(RPNode node) { - if(node.getRight() != null && node.getLeft() != null) { - node.getIndices().clear(); - } - - } - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/Cell.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/Cell.java deleted file mode 100644 index c89e72ab1..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/Cell.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
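RPUtils above caches one ReduceOp per distance name in a ThreadLocal map and re-points its x/y/z arrays on each call; computeDistance then executes the op and reads back the scalar result. A hedged usage sketch (the vectors are made up for illustration):

    import org.deeplearning4j.clustering.randomprojection.RPUtils;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class DistanceExample {
        public static void main(String[] args) {
            INDArray x = Nd4j.create(new double[] {1, 0, 0});
            INDArray y = Nd4j.create(new double[] {0, 1, 0});

            // Names the deleted dispatcher accepts: "euclidean" (the default),
            // "cosinedistance", "cosinesimilarity", "manhattan", "jaccard", "hamming".
            double d = RPUtils.computeDistance("euclidean", x, y);
            System.out.println(d); // ~1.414 for these unit vectors
        }
    }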
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.sptree; - -import org.nd4j.linalg.api.ndarray.INDArray; - -import java.io.Serializable; - -/** - * @author Adam Gibson - */ -public class Cell implements Serializable { - private int dimension; - private INDArray corner, width; - - public Cell(int dimension) { - this.dimension = dimension; - } - - public double corner(int d) { - return corner.getDouble(d); - } - - public double width(int d) { - return width.getDouble(d); - } - - public void setCorner(int d, double corner) { - this.corner.putScalar(d, corner); - } - - public void setWidth(int d, double width) { - this.width.putScalar(d, width); - } - - public void setWidth(INDArray width) { - this.width = width; - } - - public void setCorner(INDArray corner) { - this.corner = corner; - } - - - public boolean contains(INDArray point) { - INDArray cornerMinusWidth = corner.sub(width); - INDArray cornerPlusWidth = corner.add(width); - for (int d = 0; d < dimension; d++) { - double pointD = point.getDouble(d); - if (cornerMinusWidth.getDouble(d) > pointD) - return false; - if (cornerPlusWidth.getDouble(d) < pointD) - return false; - } - return true; - - } - - public INDArray width() { - return width; - } - - public INDArray corner() { - return corner; - } - - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/DataPoint.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/DataPoint.java deleted file mode 100644 index 6681d3148..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/DataPoint.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
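The N-dimensional Cell above treats corner as a centre and width as a per-dimension half-extent, so a point lies inside exactly when corner(d) - width(d) <= p(d) <= corner(d) + width(d) for every dimension d. The same check with plain arrays:

    // Plain-array version of the deleted Cell.contains(point) test.
    static boolean contains(double[] corner, double[] width, double[] point) {
        for (int d = 0; d < corner.length; d++) {
            if (point[d] < corner[d] - width[d]) return false; // below the lower face
            if (point[d] > corner[d] + width[d]) return false; // above the upper face
        }
        return true;
    }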
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.sptree; - -import lombok.Data; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity; -import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance; -import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance; -import org.nd4j.linalg.factory.Nd4j; - -import java.io.Serializable; - -@Data -public class DataPoint implements Serializable { - private int index; - private INDArray point; - private long d; - private String functionName; - private boolean invert = false; - - - public DataPoint(int index, INDArray point, boolean invert) { - this(index, point, "euclidean"); - this.invert = invert; - } - - public DataPoint(int index, INDArray point, String functionName, boolean invert) { - this.index = index; - this.point = point; - this.functionName = functionName; - this.d = point.length(); - this.invert = invert; - } - - - public DataPoint(int index, INDArray point) { - this(index, point, false); - } - - public DataPoint(int index, INDArray point, String functionName) { - this(index, point, functionName, false); - } - - /** - * Euclidean distance - * @param point the distance from this point to the given point - * @return the distance between the two points - */ - public float distance(DataPoint point) { - switch (functionName) { - case "euclidean": - float ret = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this.point, point.point)) - .getFinalResult().floatValue(); - return invert ? -ret : ret; - - case "cosinesimilarity": - float ret2 = Nd4j.getExecutioner().execAndReturn(new CosineSimilarity(this.point, point.point)) - .getFinalResult().floatValue(); - return invert ? -ret2 : ret2; - - case "manhattan": - float ret3 = Nd4j.getExecutioner().execAndReturn(new ManhattanDistance(this.point, point.point)) - .getFinalResult().floatValue(); - return invert ? -ret3 : ret3; - case "dot": - float dotRet = (float) Nd4j.getBlasWrapper().dot(this.point, point.point); - return invert ? -dotRet : dotRet; - default: - float ret4 = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this.point, point.point)) - .getFinalResult().floatValue(); - return invert ? -ret4 : ret4; - - } - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/HeapItem.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/HeapItem.java deleted file mode 100644 index a5ea6ea95..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/HeapItem.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
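DataPoint.distance dispatches on the configured function name and negates the result when invert is set, so similarity measures (cosine similarity, dot product) can be used where a smaller-is-closer ordering is expected. A hedged usage sketch with made-up vectors:

    import org.deeplearning4j.clustering.sptree.DataPoint;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class DataPointExample {
        public static void main(String[] args) {
            INDArray a = Nd4j.create(new double[] {1, 0});
            INDArray b = Nd4j.create(new double[] {0, 1});

            // invert = true flips the sign, turning "higher similarity"
            // into "lower distance" for ranking purposes.
            DataPoint p = new DataPoint(0, a, "cosinesimilarity", true);
            DataPoint q = new DataPoint(1, b, "cosinesimilarity", true);
            System.out.println(p.distance(q)); // -0.0 for orthogonal vectors
        }
    }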
See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.sptree; - -import java.io.Serializable; - -/** - * @author Adam Gibson - */ -public class HeapItem implements Serializable, Comparable { - private int index; - private double distance; - - - public HeapItem(int index, double distance) { - this.index = index; - this.distance = distance; - } - - public int getIndex() { - return index; - } - - public void setIndex(int index) { - this.index = index; - } - - public double getDistance() { - return distance; - } - - public void setDistance(double distance) { - this.distance = distance; - } - - @Override - public boolean equals(Object o) { - if (this == o) - return true; - if (o == null || getClass() != o.getClass()) - return false; - - HeapItem heapItem = (HeapItem) o; - - if (index != heapItem.index) - return false; - return Double.compare(heapItem.distance, distance) == 0; - - } - - @Override - public int hashCode() { - int result; - long temp; - result = index; - temp = Double.doubleToLongBits(distance); - result = 31 * result + (int) (temp ^ (temp >>> 32)); - return result; - } - - @Override - public int compareTo(HeapItem o) { - return distance < o.distance ? 1 : 0; - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/HeapObject.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/HeapObject.java deleted file mode 100644 index e68cf33ec..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/HeapObject.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
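Note that HeapItem above (and HeapObject below) implement compareTo as distance < o.distance ? 1 : 0, which never returns a negative value and is inconsistent with equals, so sorting with it violates the Comparable contract. A conventional ascending-by-distance implementation would be:

    @Override
    public int compareTo(HeapItem o) {
        // Ascending by distance; returns -1, 0, or 1 as the contract requires.
        return Double.compare(this.distance, o.distance);
    }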
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.sptree; - -import lombok.Data; -import org.nd4j.linalg.api.ndarray.INDArray; - -import java.io.Serializable; - -@Data -public class HeapObject implements Serializable, Comparable { - private int index; - private INDArray point; - private double distance; - - - public HeapObject(int index, INDArray point, double distance) { - this.index = index; - this.point = point; - this.distance = distance; - } - - - @Override - public boolean equals(Object o) { - if (this == o) - return true; - if (o == null || getClass() != o.getClass()) - return false; - - HeapObject heapObject = (HeapObject) o; - - if (!point.equals(heapObject.point)) - return false; - - return Double.compare(heapObject.distance, distance) == 0; - - } - - @Override - public int hashCode() { - int result; - long temp; - result = index; - temp = Double.doubleToLongBits(distance); - result = 31 * result + (int) (temp ^ (temp >>> 32)); - return result; - } - - @Override - public int compareTo(HeapObject o) { - return distance < o.distance ? 1 : 0; - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java deleted file mode 100644 index 4a1bf34e4..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java +++ /dev/null @@ -1,425 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
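The SpTree that follows generalizes the 2-D QuadTree earlier in this diff to D dimensions; both maintain a running centre of mass per node and apply the same Barnes-Hut acceptance test, summarizing a whole subtree by its centre of mass whenever maxWidth / sqrt(distSquared) < theta, where distSquared is the squared distance from the query point to that centre. A self-contained sketch of the core updates, assuming plain double[] points:

    final class BarnesHutSketch {
        // Incremental centre-of-mass update performed on every insert:
        // com = com * (n - 1) / n + point / n
        static void updateCenterOfMass(double[] com, double[] point, int n) {
            double mult1 = (n - 1.0) / n, mult2 = 1.0 / n;
            for (int d = 0; d < com.length; d++)
                com[d] = com[d] * mult1 + point[d] * mult2;
        }

        // Barnes-Hut acceptance test: a node may stand in for all of its
        // points when it is small relative to its distance from the query.
        static boolean canSummarize(double maxWidth, double distSquared, double theta) {
            return maxWidth / Math.sqrt(distSquared) < theta;
        }

        // The t-SNE negative-force contribution accumulated for a summarized
        // node: Q = 1 / (1 + distSquared), weighted by the node's point count.
        // Returns the amount added to the normalizer sumQ.
        static double addNegativeForce(double[] diff, double distSquared, int cumSize,
                                       double[] negativeForce) {
            double q = 1.0 / (1.0 + distSquared);
            double mult = cumSize * q;
            for (int d = 0; d < diff.length; d++)
                negativeForce[d] += mult * q * diff[d];
            return mult;
        }
    }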
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.sptree; - -import org.nd4j.shade.guava.util.concurrent.AtomicDouble; -import lombok.val; -import org.deeplearning4j.clustering.algorithm.Distance; -import org.deeplearning4j.nn.conf.WorkspaceMode; -import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Set; - - -/** - * @author Adam Gibson - */ -public class SpTree implements Serializable { - - - public final static String workspaceExternal = "SPTREE_LOOP_EXTERNAL"; - - - private int D; - private INDArray data; - public final static int NODE_RATIO = 8000; - private int N; - private int size; - private int cumSize; - private Cell boundary; - private INDArray centerOfMass; - private SpTree parent; - private int[] index; - private int nodeCapacity; - private int numChildren = 2; - private boolean isLeaf = true; - private Collection indices; - private SpTree[] children; - private static Logger log = LoggerFactory.getLogger(SpTree.class); - private String similarityFunction = Distance.EUCLIDEAN.toString(); - - - - public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection indices, - String similarityFunction) { - init(parent, data, corner, width, indices, similarityFunction); - } - - - public SpTree(INDArray data, Collection indices, String similarityFunction) { - this.indices = indices; - this.N = data.rows(); - this.D = data.columns(); - this.similarityFunction = similarityFunction; - data = data.dup(); - INDArray meanY = data.mean(0); - INDArray minY = data.min(0); - INDArray maxY = data.max(0); - INDArray width = Nd4j.create(data.dataType(), meanY.shape()); - for (int i = 0; i < width.length(); i++) { - width.putScalar(i, Math.max(maxY.getDouble(i) - meanY.getDouble(i), - meanY.getDouble(i) - minY.getDouble(i)) + Nd4j.EPS_THRESHOLD); - } - - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - init(null, data, meanY, width, indices, similarityFunction); - fill(N); - } - } - - - public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection indices) { - this(parent, data, corner, width, indices, "euclidean"); - } - - - public SpTree(INDArray data, Collection indices) { - this(data, indices, "euclidean"); - } - - - - public SpTree(INDArray data) { - this(data, new ArrayList()); - } - - public MemoryWorkspace workspace() { - return null; - } - - private void init(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection indices, - String similarityFunction) { - - this.parent = parent; - D = data.columns(); - N = data.rows(); - this.similarityFunction = similarityFunction; - nodeCapacity = N % NODE_RATIO; - index = new int[nodeCapacity]; - for (int d = 1; d < this.D; d++) - numChildren *= 2; - this.indices = indices; - isLeaf = true; - size = 0; - cumSize = 0; - children = new SpTree[numChildren]; - this.data = data; - boundary = new Cell(D); - boundary.setCorner(corner.dup()); - boundary.setWidth(width.dup()); - centerOfMass = Nd4j.create(data.dataType(), D); - } - - - - private boolean insert(int index) { - 
/*MemoryWorkspace workspace = - workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace() - : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread( - workspaceConfigurationExternal, - workspaceExternal); - try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ { - - INDArray point = data.slice(index); - /*boolean contains = false; - SpTreeCell op = new SpTreeCell(boundary.corner(), boundary.width(), point, N, contains); - Nd4j.getExecutioner().exec(op); - op.getOutputArgument(0).getScalar(0); - if (!contains) return false;*/ - if (!boundary.contains(point)) - return false; - - - cumSize++; - double mult1 = (double) (cumSize - 1) / (double) cumSize; - double mult2 = 1.0 / (double) cumSize; - centerOfMass.muli(mult1); - centerOfMass.addi(point.mul(mult2)); - // If there is space in this quad tree and it is a leaf, add the object here - if (isLeaf() && size < nodeCapacity) { - this.index[size] = index; - indices.add(point); - size++; - return true; - } - - - for (int i = 0; i < size; i++) { - INDArray compPoint = data.slice(this.index[i]); - if (compPoint.equals(point)) - return true; - } - - - if (isLeaf()) - subDivide(); - - - // Find out where the point can be inserted - for (int i = 0; i < numChildren; i++) { - if (children[i].insert(index)) - return true; - } - - throw new IllegalStateException("Shouldn't reach this state"); - } - } - - - /** - * Subdivide the node in to - * 4 children - */ - public void subDivide() { - /*MemoryWorkspace workspace = - workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace() - : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread( - workspaceConfigurationExternal, - workspaceExternal); - try (MemoryWorkspace ws = workspace.notifyScopeEntered()) */{ - - INDArray newCorner = Nd4j.create(data.dataType(), D); - INDArray newWidth = Nd4j.create(data.dataType(), D); - for (int i = 0; i < numChildren; i++) { - int div = 1; - for (int d = 0; d < D; d++) { - newWidth.putScalar(d, .5 * boundary.width(d)); - if ((i / div) % 2 == 1) - newCorner.putScalar(d, boundary.corner(d) - .5 * boundary.width(d)); - else - newCorner.putScalar(d, boundary.corner(d) + .5 * boundary.width(d)); - div *= 2; - } - - children[i] = new SpTree(this, data, newCorner, newWidth, indices); - - } - - // Move existing points to correct children - for (int i = 0; i < size; i++) { - boolean success = false; - for (int j = 0; j < this.numChildren; j++) - if (!success) - success = children[j].insert(index[i]); - - index[i] = -1; - } - - // Empty parent node - size = 0; - isLeaf = false; - } - } - - - - /** - * Compute non edge forces using barnes hut - * @param pointIndex - * @param theta - * @param negativeForce - * @param sumQ - */ - public void computeNonEdgeForces(int pointIndex, double theta, INDArray negativeForce, AtomicDouble sumQ) { - // Make sure that we spend no time on empty nodes or self-interactions - INDArray buf = Nd4j.create(data.dataType(), this.D); - - if (cumSize == 0 || (isLeaf() && size == 1 && index[0] == pointIndex)) - return; - /* MemoryWorkspace workspace = - workspaceMode == WorkspaceMode.NONE ? 
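As an aside on the arithmetic in insert() above: the mult1/mult2 pair is the standard incremental-mean update, com_n = com_{n-1} * (n-1)/n + x/n, which keeps a cell's centre of mass exact without ever storing its points. A small self-contained sketch of the same update on plain arrays (names and values illustrative, not from this codebase):

public class IncrementalMeanExample {
    public static void main(String[] args) {
        double[] centerOfMass = new double[2];
        double[][] points = {{1, 1}, {3, 3}, {5, 5}};
        int cumSize = 0;
        for (double[] p : points) {
            cumSize++;
            double mult1 = (double) (cumSize - 1) / cumSize;
            double mult2 = 1.0 / cumSize;
            for (int d = 0; d < p.length; d++)
                centerOfMass[d] = centerOfMass[d] * mult1 + p[d] * mult2;
        }
        // centerOfMass is now {3.0, 3.0}, the exact mean of the inserted points
        System.out.println(java.util.Arrays.toString(centerOfMass));
    }
}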
new DummyWorkspace() - : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread( - workspaceConfigurationExternal, - workspaceExternal); - try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ { - - // Compute distance between point and center-of-mass - data.slice(pointIndex).subi(centerOfMass, buf); - - double D = Nd4j.getBlasWrapper().dot(buf, buf); - // Check whether we can use this node as a "summary" - double maxWidth = boundary.width().maxNumber().doubleValue(); - // Check whether we can use this node as a "summary" - if (isLeaf() || maxWidth / Math.sqrt(D) < theta) { - - // Compute and add t-SNE force between point and current node - double Q = 1.0 / (1.0 + D); - double mult = cumSize * Q; - sumQ.addAndGet(mult); - mult *= Q; - negativeForce.addi(buf.mul(mult)); - } else { - - // Recursively apply Barnes-Hut to children - for (int i = 0; i < numChildren; i++) { - children[i].computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ); - } - - } - } - } - - - /** - * - * Compute edge forces using barnes hut - * @param rowP a vector - * @param colP - * @param valP - * @param N the number of elements - * @param posF the positive force - */ - public void computeEdgeForces(INDArray rowP, INDArray colP, INDArray valP, int N, INDArray posF) { - if (!rowP.isVector()) - throw new IllegalArgumentException("RowP must be a vector"); - - // Loop over all edges in the graph - // just execute native op - Nd4j.exec(new BarnesEdgeForces(rowP, colP, valP, data, N, posF)); - - /* - INDArray buf = Nd4j.create(data.dataType(), this.D); - double D; - for (int n = 0; n < N; n++) { - INDArray slice = data.slice(n); - for (int i = rowP.getInt(n); i < rowP.getInt(n + 1); i++) { - - // Compute pairwise distance and Q-value - slice.subi(data.slice(colP.getInt(i)), buf); - - D = 1.0 + Nd4j.getBlasWrapper().dot(buf, buf); - D = valP.getDouble(i) / D; - - // Sum positive force - posF.slice(n).addi(buf.muli(D)); - } - } - */ - } - - - - public boolean isLeaf() { - return isLeaf; - } - - /** - * Verifies the structure of the tree (does bounds checking on each node) - * @return true if the structure of the tree - * is correct. - */ - public boolean isCorrect() { - /*MemoryWorkspace workspace = - workspaceMode == WorkspaceMode.NONE ? 
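computeNonEdgeForces above is the Barnes-Hut approximation used by t-SNE: a whole subtree may stand in for its points whenever maxWidth / sqrt(D) < theta, in which case it contributes cumSize * Q to the normalisation term and cumSize * Q^2 (times the difference vector) to the repulsive force, where Q = 1 / (1 + D) is the Student-t kernel. A scalar sketch of that acceptance test and force term, with illustrative inputs (not from this codebase):

public class BarnesHutSketch {

    // Repulsive contribution of one summary node, or Double.NaN when the cell
    // is too wide relative to the distance and its children must be visited.
    static double summaryForce(double maxCellWidth, double squaredDist,
                               int pointsInCell, double theta) {
        if (maxCellWidth / Math.sqrt(squaredDist) >= theta)
            return Double.NaN; // cell too coarse: recurse into children instead
        double q = 1.0 / (1.0 + squaredDist); // Student-t kernel, as in computeNonEdgeForces
        return pointsInCell * q * q;          // mult = cumSize * Q, then mult *= Q
    }

    public static void main(String[] args) {
        System.out.println(summaryForce(0.5, 9.0, 100, 0.5)); // summary accepted: 1.0
        System.out.println(summaryForce(4.0, 1.0, 100, 0.5)); // NaN: must recurse
    }
}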
new DummyWorkspace()
-                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
-                                        workspaceConfigurationExternal,
-                                        workspaceExternal);
-        try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ {
-
-            for (int n = 0; n < size; n++) {
-                INDArray point = data.slice(index[n]);
-                if (!boundary.contains(point))
-                    return false;
-            }
-            if (!isLeaf()) {
-                boolean correct = true;
-                for (int i = 0; i < numChildren; i++)
-                    correct = correct && children[i].isCorrect();
-                return correct;
-            }
-
-            return true;
-        }
-    }
-
-    /**
-     * The depth of the node
-     * @return the depth of the node
-     */
-    public int depth() {
-        if (isLeaf())
-            return 1;
-        int depth = 1;
-        int maxChildDepth = 0;
-        for (int i = 0; i < numChildren; i++) {
-            // take the maximum over every child, not just the first
-            maxChildDepth = Math.max(maxChildDepth, children[i].depth());
-        }
-
-        return depth + maxChildDepth;
-    }
-
-    private void fill(int n) {
-        if (indices.isEmpty() && parent == null)
-            for (int i = 0; i < n; i++) {
-                log.trace("Inserted {}", i);
-                insert(i);
-            }
-        else
-            log.warn("Called fill already");
-    }
-
-
-    public SpTree[] getChildren() {
-        return children;
-    }
-
-    public int getD() {
-        return D;
-    }
-
-    public INDArray getCenterOfMass() {
-        return centerOfMass;
-    }
-
-    public Cell getBoundary() {
-        return boundary;
-    }
-
-    public int[] getIndex() {
-        return index;
-    }
-
-    public int getCumSize() {
-        return cumSize;
-    }
-
-    public void setCumSize(int cumSize) {
-        this.cumSize = cumSize;
-    }
-
-    public int getNumChildren() {
-        return numChildren;
-    }
-
-    public void setNumChildren(int numChildren) {
-        this.numChildren = numChildren;
-    }
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/BaseClusteringStrategy.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/BaseClusteringStrategy.java
deleted file mode 100644
index daada687f..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/BaseClusteringStrategy.java
+++ /dev/null
@@ -1,117 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.strategy; - -import lombok.*; -import org.deeplearning4j.clustering.algorithm.Distance; -import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition; -import org.deeplearning4j.clustering.condition.ConvergenceCondition; -import org.deeplearning4j.clustering.condition.FixedIterationCountCondition; - -import java.io.Serializable; - -@AllArgsConstructor(access = AccessLevel.PROTECTED) -@NoArgsConstructor(access = AccessLevel.PROTECTED) -public abstract class BaseClusteringStrategy implements ClusteringStrategy, Serializable { - @Getter(AccessLevel.PUBLIC) - @Setter(AccessLevel.PROTECTED) - protected ClusteringStrategyType type; - @Getter(AccessLevel.PUBLIC) - @Setter(AccessLevel.PROTECTED) - protected Integer initialClusterCount; - @Getter(AccessLevel.PUBLIC) - @Setter(AccessLevel.PROTECTED) - protected ClusteringAlgorithmCondition optimizationPhaseCondition; - @Getter(AccessLevel.PUBLIC) - @Setter(AccessLevel.PROTECTED) - protected ClusteringAlgorithmCondition terminationCondition; - @Getter(AccessLevel.PUBLIC) - @Setter(AccessLevel.PROTECTED) - protected boolean inverse; - @Getter(AccessLevel.PUBLIC) - @Setter(AccessLevel.PROTECTED) - protected Distance distanceFunction; - @Getter(AccessLevel.PUBLIC) - @Setter(AccessLevel.PROTECTED) - protected boolean allowEmptyClusters; - - public BaseClusteringStrategy(ClusteringStrategyType type, Integer initialClusterCount, Distance distanceFunction, - boolean allowEmptyClusters, boolean inverse) { - this.type = type; - this.initialClusterCount = initialClusterCount; - this.distanceFunction = distanceFunction; - this.allowEmptyClusters = allowEmptyClusters; - this.inverse = inverse; - } - - public BaseClusteringStrategy(ClusteringStrategyType clusteringStrategyType, int initialClusterCount, - Distance distanceFunction, boolean inverse) { - this(clusteringStrategyType, initialClusterCount, distanceFunction, false, inverse); - } - - - /** - * - * @param maxIterationCount - * @return - */ - public BaseClusteringStrategy endWhenIterationCountEquals(int maxIterationCount) { - setTerminationCondition(FixedIterationCountCondition.iterationCountGreaterThan(maxIterationCount)); - return this; - } - - /** - * - * @param rate - * @return - */ - public BaseClusteringStrategy endWhenDistributionVariationRateLessThan(double rate) { - setTerminationCondition(ConvergenceCondition.distributionVariationRateLessThan(rate)); - return this; - } - - /** - * @return - */ - @Override - public boolean inverseDistanceCalculation() { - return inverse; - } - - /** - * - * @param type - * @return - */ - public boolean isStrategyOfType(ClusteringStrategyType type) { - return type.equals(this.type); - } - - /** - * - * @return - */ - public Integer getInitialClusterCount() { - return initialClusterCount; - } - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/ClusteringStrategy.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/ClusteringStrategy.java deleted file mode 100644 index 2ec9fcd47..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/ClusteringStrategy.java +++ /dev/null @@ -1,102 +0,0 @@ -/* - * 
****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.strategy; - -import org.deeplearning4j.clustering.algorithm.Distance; -import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition; -import org.deeplearning4j.clustering.iteration.IterationHistory; - -/** - * - */ -public interface ClusteringStrategy { - - /** - * - * @return - */ - boolean inverseDistanceCalculation(); - - /** - * - * @return - */ - ClusteringStrategyType getType(); - - /** - * - * @param type - * @return - */ - boolean isStrategyOfType(ClusteringStrategyType type); - - /** - * - * @return - */ - Integer getInitialClusterCount(); - - /** - * - * @return - */ - Distance getDistanceFunction(); - - /** - * - * @return - */ - boolean isAllowEmptyClusters(); - - /** - * - * @return - */ - ClusteringAlgorithmCondition getTerminationCondition(); - - /** - * - * @return - */ - boolean isOptimizationDefined(); - - /** - * - * @param iterationHistory - * @return - */ - boolean isOptimizationApplicableNow(IterationHistory iterationHistory); - - /** - * - * @param maxIterationCount - * @return - */ - BaseClusteringStrategy endWhenIterationCountEquals(int maxIterationCount); - - /** - * - * @param rate - * @return - */ - BaseClusteringStrategy endWhenDistributionVariationRateLessThan(double rate); - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/ClusteringStrategyType.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/ClusteringStrategyType.java deleted file mode 100644 index 9f72bba95..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/ClusteringStrategyType.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.strategy; - -public enum ClusteringStrategyType { - FIXED_CLUSTER_COUNT, OPTIMIZATION -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/FixedClusterCountStrategy.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/FixedClusterCountStrategy.java deleted file mode 100644 index 18eceb34f..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/FixedClusterCountStrategy.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.strategy; - -import lombok.AccessLevel; -import lombok.NoArgsConstructor; -import org.deeplearning4j.clustering.algorithm.Distance; -import org.deeplearning4j.clustering.iteration.IterationHistory; - -/** - * - */ -@NoArgsConstructor(access = AccessLevel.PROTECTED) -public class FixedClusterCountStrategy extends BaseClusteringStrategy { - - - protected FixedClusterCountStrategy(Integer initialClusterCount, Distance distanceFunction, - boolean allowEmptyClusters, boolean inverse) { - super(ClusteringStrategyType.FIXED_CLUSTER_COUNT, initialClusterCount, distanceFunction, allowEmptyClusters, - inverse); - } - - /** - * - * @param clusterCount - * @param distanceFunction - * @param inverse - * @return - */ - public static FixedClusterCountStrategy setup(int clusterCount, Distance distanceFunction, boolean inverse) { - return new FixedClusterCountStrategy(clusterCount, distanceFunction, false, inverse); - } - - /** - * @return - */ - @Override - public boolean inverseDistanceCalculation() { - return inverse; - } - - public boolean isOptimizationDefined() { - return false; - } - - public boolean isOptimizationApplicableNow(IterationHistory iterationHistory) { - return false; - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/OptimisationStrategy.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/OptimisationStrategy.java deleted file mode 100644 index dc9385296..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/OptimisationStrategy.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * 
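Taken together, the strategy classes in this diff are configured through their static factories plus the fluent termination conditions inherited from BaseClusteringStrategy. A minimal sketch of a typical fixed-cluster-count setup follows; the cluster count and iteration cap are illustrative, not values from this repository:

import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.strategy.BaseClusteringStrategy;
import org.deeplearning4j.clustering.strategy.FixedClusterCountStrategy;

public class StrategyExample {
    public static void main(String[] args) {
        // fixed k = 5, plain (non-inverse) euclidean distance,
        // terminating after at most 100 iterations
        BaseClusteringStrategy strategy = FixedClusterCountStrategy
                        .setup(5, Distance.EUCLIDEAN, false)
                        .endWhenIterationCountEquals(100);
    }
}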
****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.strategy; - -import org.deeplearning4j.clustering.algorithm.Distance; -import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition; -import org.deeplearning4j.clustering.condition.ConvergenceCondition; -import org.deeplearning4j.clustering.condition.FixedIterationCountCondition; -import org.deeplearning4j.clustering.iteration.IterationHistory; -import org.deeplearning4j.clustering.optimisation.ClusteringOptimization; -import org.deeplearning4j.clustering.optimisation.ClusteringOptimizationType; - -public class OptimisationStrategy extends BaseClusteringStrategy { - public static int defaultIterationCount = 100; - - private ClusteringOptimization clusteringOptimisation; - private ClusteringAlgorithmCondition clusteringOptimisationApplicationCondition; - - protected OptimisationStrategy() { - super(); - } - - protected OptimisationStrategy(int initialClusterCount, Distance distanceFunction) { - super(ClusteringStrategyType.OPTIMIZATION, initialClusterCount, distanceFunction, false); - } - - public static OptimisationStrategy setup(int initialClusterCount, Distance distanceFunction) { - return new OptimisationStrategy(initialClusterCount, distanceFunction); - } - - public OptimisationStrategy optimize(ClusteringOptimizationType type, double value) { - clusteringOptimisation = new ClusteringOptimization(type, value); - return this; - } - - public OptimisationStrategy optimizeWhenIterationCountMultipleOf(int value) { - clusteringOptimisationApplicationCondition = FixedIterationCountCondition.iterationCountGreaterThan(value); - return this; - } - - public OptimisationStrategy optimizeWhenPointDistributionVariationRateLessThan(double rate) { - clusteringOptimisationApplicationCondition = ConvergenceCondition.distributionVariationRateLessThan(rate); - return this; - } - - - public double getClusteringOptimizationValue() { - return clusteringOptimisation.getValue(); - } - - public boolean isClusteringOptimizationType(ClusteringOptimizationType type) { - return clusteringOptimisation != null && clusteringOptimisation.getType().equals(type); - } - - public boolean isOptimizationDefined() { - return clusteringOptimisation != null; - } - - public boolean isOptimizationApplicableNow(IterationHistory iterationHistory) { - return clusteringOptimisationApplicationCondition != null - && clusteringOptimisationApplicationCondition.isSatisfied(iterationHistory); - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/util/MathUtils.java 
b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/util/MathUtils.java deleted file mode 100755 index 2290c6269..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/util/MathUtils.java +++ /dev/null @@ -1,1327 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.util; - - -import org.apache.commons.math3.linear.CholeskyDecomposition; -import org.apache.commons.math3.linear.NonSquareMatrixException; -import org.apache.commons.math3.linear.RealMatrix; -import org.apache.commons.math3.random.RandomGenerator; -import org.apache.commons.math3.util.FastMath; -import org.nd4j.common.primitives.Counter; - -import java.util.ArrayList; -import java.util.List; -import java.util.Random; -import java.util.Set; - - -public class MathUtils { - - /** The natural logarithm of 2. 
*/ - public static double log2 = Math.log(2); - - /** - * Normalize a value - * (val - min) / (max - min) - * @param val value to normalize - * @param max max value - * @param min min value - * @return the normalized value - */ - public static double normalize(double val, double min, double max) { - if (max < min) - throw new IllegalArgumentException("Max must be greater than min"); - - return (val - min) / (max - min); - } - - /** - * Clamps the value to a discrete value - * @param value the value to clamp - * @param min min for the probability distribution - * @param max max for the probability distribution - * @return the discrete value - */ - public static int clamp(int value, int min, int max) { - if (value < min) - value = min; - if (value > max) - value = max; - return value; - } - - /** - * Discretize the given value - * @param value the value to discretize - * @param min the min of the distribution - * @param max the max of the distribution - * @param binCount the number of bins - * @return the discretized value - */ - public static int discretize(double value, double min, double max, int binCount) { - int discreteValue = (int) (binCount * normalize(value, min, max)); - return clamp(discreteValue, 0, binCount - 1); - } - - - /** - * See: https://stackoverflow.com/questions/466204/rounding-off-to-nearest-power-of-2 - * @param v the number to getFromOrigin the next power of 2 for - * @return the next power of 2 for the passed in value - */ - public static long nextPowOf2(long v) { - v--; - v |= v >> 1; - v |= v >> 2; - v |= v >> 4; - v |= v >> 8; - v |= v >> 16; - v++; - return v; - - } - - - - /** - * Generates a binomial distributed number using - * the given rng - * @param rng - * @param n - * @param p - * @return - */ - public static int binomial(RandomGenerator rng, int n, double p) { - if ((p < 0) || (p > 1)) { - return 0; - } - int c = 0; - for (int i = 0; i < n; i++) { - if (rng.nextDouble() < p) { - c++; - } - } - return c; - } - - /** - * Generate a uniform random number from the given rng - * @param rng the rng to use - * @param min the min num - * @param max the max num - * @return a number uniformly distributed between min and max - */ - public static double uniform(Random rng, double min, double max) { - return rng.nextDouble() * (max - min) + min; - } - - /** - * Returns the correlation coefficient of two double vectors. 
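As a quick numerical check of the normalize/clamp/discretize trio defined above: normalize maps a value into [0, 1], multiplying by binCount and truncating picks a bin, and clamp keeps boundary values inside the valid index range. The values below are illustrative:

import org.deeplearning4j.clustering.util.MathUtils;

public class DiscretizeExample {
    public static void main(String[] args) {
        double min = 0.0, max = 10.0;
        int binCount = 5;
        // normalize(7.5, 0, 10) = 0.75; 0.75 * 5 = 3.75 -> bin 3 after truncation
        System.out.println(MathUtils.discretize(7.5, min, max, binCount)); // 3
        // boundary value: normalize(10) = 1.0 -> 5, clamped to binCount - 1 = 4
        System.out.println(MathUtils.discretize(10.0, min, max, binCount)); // 4
    }
}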
- * - * @param residuals residuals - * @param targetAttribute target attribute vector - * - * @return the correlation coefficient or r - */ - public static double correlation(double[] residuals, double targetAttribute[]) { - double[] predictedValues = new double[residuals.length]; - for (int i = 0; i < predictedValues.length; i++) { - predictedValues[i] = targetAttribute[i] - residuals[i]; - } - double ssErr = ssError(predictedValues, targetAttribute); - double total = ssTotal(residuals, targetAttribute); - return 1 - (ssErr / total); - }//end correlation - - /** - * 1 / 1 + exp(-x) - * @param x - * @return - */ - public static double sigmoid(double x) { - return 1.0 / (1.0 + FastMath.exp(-x)); - } - - - /** - * How much of the variance is explained by the regression - * @param residuals error - * @param targetAttribute data for target attribute - * @return the sum squares of regression - */ - public static double ssReg(double[] residuals, double[] targetAttribute) { - double mean = sum(targetAttribute) / targetAttribute.length; - double ret = 0; - for (int i = 0; i < residuals.length; i++) { - ret += Math.pow(residuals[i] - mean, 2); - } - return ret; - } - - /** - * How much of the variance is NOT explained by the regression - * @param predictedValues predicted values - * @param targetAttribute data for target attribute - * @return the sum squares of regression - */ - public static double ssError(double[] predictedValues, double[] targetAttribute) { - double ret = 0; - for (int i = 0; i < predictedValues.length; i++) { - ret += Math.pow(targetAttribute[i] - predictedValues[i], 2); - } - return ret; - - } - - - /** - * Calculate string similarity with tfidf weights relative to each character - * frequency and how many times a character appears in a given string - * @param strings the strings to calculate similarity for - * @return the cosine similarity between the strings - */ - public static double stringSimilarity(String... 
strings) {
-        if (strings == null || strings.length < 2)
-            return 0;
-        Counter<String> counter = new Counter<>();
-        Counter<String> counter2 = new Counter<>();
-
-        for (int i = 0; i < strings[0].length(); i++)
-            counter.incrementCount(String.valueOf(strings[0].charAt(i)), 1.0f);
-
-        for (int i = 0; i < strings[1].length(); i++)
-            counter2.incrementCount(String.valueOf(strings[1].charAt(i)), 1.0f);
-        Set<String> v1 = counter.keySet();
-        Set<String> v2 = counter2.keySet();
-
-        Set<String> both = SetUtils.intersection(v1, v2);
-
-        double scalar = 0, norm1 = 0, norm2 = 0;
-        for (String k : both)
-            scalar += counter.getCount(k) * counter2.getCount(k);
-        for (String k : v1)
-            norm1 += counter.getCount(k) * counter.getCount(k);
-        for (String k : v2)
-            norm2 += counter2.getCount(k) * counter2.getCount(k);
-        return scalar / Math.sqrt(norm1 * norm2);
-    }
-
-    /**
-     * Returns the vector length: sqrt(sum(x_i^2))
-     * @param vector the vector to return the vector length for
-     * @return the vector length of the passed in array
-     */
-    public static double vectorLength(double[] vector) {
-        double ret = 0;
-        if (vector == null)
-            return ret;
-        for (int i = 0; i < vector.length; i++) {
-            ret += Math.pow(vector[i], 2);
-        }
-        // length is the square root of the sum of squared components
-        return Math.sqrt(ret);
-    }
-
-    /**
-     * Inverse document frequency: the total docs divided by the number of times the word
-     * appeared in a document
-     * @param totalDocs the total documents for the data set
-     * @param numTimesWordAppearedInADocument the number of times the word occurred in a document
-     * @return log10(totalDocs / numTimesWordAppearedInADocument), or 0 when totalDocs is 0
-     */
-    public static double idf(double totalDocs, double numTimesWordAppearedInADocument) {
-        //return totalDocs > 0 ? Math.log10(totalDocs/numTimesWordAppearedInADocument) : 0;
-        if (totalDocs == 0)
-            return 0;
-        return Math.log10(totalDocs / numTimesWordAppearedInADocument);
-    }
-
-    /**
-     * Term frequency: the number of occurrences of the term divided by the document length
-     * @param count the count of a word or character in a given string or document
-     * @param documentLength the length of the document the count was taken from
-     * @return count / documentLength
-     */
-    public static double tf(int count, int documentLength) {
-        //return count > 0 ? 1 + Math.log10(count) : 0
-        return (double) count / documentLength;
-    }
-
-    /**
-     * Return tf * idf
-     * @param tf the term frequency (assumed calculated)
-     * @param idf inverse document frequency (assumed calculated)
-     * @return tf * idf
-     */
-    public static double tfidf(double tf, double idf) {
-        // System.out.println("TF-IDF Value: " + (tf * idf));
-        return tf * idf;
-    }
-
-    private static int charForLetter(char c) {
-        char[] chars = {'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's',
-                        't', 'u', 'v', 'w', 'x', 'y', 'z'};
-        for (int i = 0; i < chars.length; i++)
-            if (chars[i] == c)
-                return i;
-        return -1;
-    }
-
-    /**
-     * Total variance in target attribute
-     * @param residuals error
-     * @param targetAttribute data for target attribute
-     * @return Total variance in target attribute
-     */
-    public static double ssTotal(double[] residuals, double[] targetAttribute) {
-        return ssReg(residuals, targetAttribute) + ssError(residuals, targetAttribute);
-    }
-
-    /**
-     * This returns the sum of the given array.
-     * @param nums the array of numbers to sum
-     * @return the sum of the given array
-     */
-    public static double sum(double[] nums) {
-
-        double ret = 0;
-        for (double d : nums)
-            ret += d;
-
-        return ret;
-    }//end sum
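The tf/idf/tfidf helpers above compose in the usual way: tf is a within-document frequency and idf damps terms that appear in many documents. A small worked sketch using the helpers' own conventions; the corpus numbers are illustrative:

import org.deeplearning4j.clustering.util.MathUtils;

public class TfIdfExample {
    public static void main(String[] args) {
        int count = 3;            // the term occurs 3 times...
        int documentLength = 100; // ...in a 100-token document
        double totalDocs = 1000;
        double docsWithTerm = 10;

        double tf = MathUtils.tf(count, documentLength);     // 0.03
        double idf = MathUtils.idf(totalDocs, docsWithTerm); // log10(100) = 2.0
        System.out.println(MathUtils.tfidf(tf, idf));        // 0.06
    }
}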
-
-    /**
-     * This will merge the coordinates of the given coordinate system.
-     * @param x the x coordinates
-     * @param y the y coordinates
-     * @return a vector such that each (x,y) pair is at ret[2*i], ret[2*i+1]
-     */
-    public static double[] mergeCoords(double[] x, double[] y) {
-        if (x.length != y.length)
-            throw new IllegalArgumentException(
-                            "Sample sizes must be the same for each data set.");
-        double[] ret = new double[x.length + y.length];
-
-        for (int i = 0; i < x.length; i++) {
-            // interleave: x at even indices, y at odd indices
-            ret[2 * i] = x[i];
-            ret[2 * i + 1] = y[i];
-        }
-        return ret;
-    }//end mergeCoords
-
-    /**
-     * This will merge the coordinates of the given coordinate system.
-     * @param x the x coordinates
-     * @param y the y coordinates
-     * @return a list such that each (x,y) pair is at ret.get(2*i), ret.get(2*i+1)
-     */
-    public static List<Double> mergeCoords(List<Double> x, List<Double> y) {
-        if (x.size() != y.size())
-            throw new IllegalArgumentException(
-                            "Sample sizes must be the same for each data set.");
-
-        List<Double> ret = new ArrayList<>();
-
-        for (int i = 0; i < x.size(); i++) {
-            ret.add(x.get(i));
-            ret.add(y.get(i));
-        }
-        return ret;
-    }//end mergeCoords
-
-    /**
-     * This returns the minimized loss values for a given vector.
-     * It is assumed that the x, y pairs are interleaved:
-     * x values at even indices, y values at odd indices
-     * @param vector the vector of numbers to get the weights for
-     * @return a double array with w_0 at index 0 and w_1 at index 1
-     */
-    public static double[] weightsFor(List<Double> vector) {
-        /* split coordinate system */
-        List<double[]> coords = coordSplit(vector);
-        /* x vals */
-        double[] x = coords.get(0);
-        /* y vals */
-        double[] y = coords.get(1);
-
-
-        double meanX = sum(x) / x.length;
-        double meanY = sum(y) / y.length;
-
-        double sumOfMeanDifferences = sumOfMeanDifferences(x, y);
-        double xDifferenceOfMean = sumOfMeanDifferencesOnePoint(x);
-
-        double w_1 = sumOfMeanDifferences / xDifferenceOfMean;
-
-        double w_0 = meanY - (w_1) * meanX;
-
-        //double w_1=(n*sumOfProducts(x,y) - sum(x) * sum(y))/(n*sumOfSquares(x) - Math.pow(sum(x),2));
-
-        // double w_0=(sum(y) - (w_1 * sum(x)))/n;
-
-        // only the first two slots are meaningful; the remainder stays zero
-        double[] ret = new double[vector.size()];
-        ret[0] = w_0;
-        ret[1] = w_1;
-
-        return ret;
-    }//end weightsFor
-
-    /**
-     * This will return the squared loss of the given
-     * points
-     * @param x the x coordinates to use
-     * @param y the y coordinates to use
-     * @param w_0 the first weight
-     *
-     * @param w_1 the second weight
-     * @return the squared loss of the given points
-     */
-    public static double squaredLoss(double[] x, double[] y, double w_0, double w_1) {
-        double sum = 0;
-        for (int j = 0; j < x.length; j++) {
-            sum += Math.pow((y[j] - (w_1 * x[j] + w_0)), 2);
-        }
-        return sum;
-    }//end squaredLoss
-
-
-    public static double w_1(double[] x, double[] y, int n) {
-        return (n * sumOfProducts(x, y) - sum(x) * sum(y)) / (n * sumOfSquares(x) - Math.pow(sum(x), 2));
-    }
-
-    public static double w_0(double[] x, double[] y, int n) {
-        double weight1 = w_1(x, y, n);
-
-        return (sum(y) - (weight1 * sum(x))) / n;
-    }
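The weightsFor variants in this file implement one-variable least squares: w_1 = sum((x_i - meanX)(y_i - meanY)) / sum((x_i - meanX)^2) and w_0 = meanY - w_1 * meanX. A worked check on an exactly linear data set (the values are illustrative, not from this repository):

import org.deeplearning4j.clustering.util.MathUtils;

public class RegressionExample {
    public static void main(String[] args) {
        // y = 2x + 1 exactly, interleaved as (x, y) pairs
        double[] interleaved = {1, 3, 2, 5, 3, 7};
        double[] w = MathUtils.weightsFor(interleaved);
        // intercept w_0 = 1.0 and slope w_1 = 2.0 recover the line exactly
        System.out.println("w_0 = " + w[0] + ", w_1 = " + w[1]);
    }
}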
-
-    /**
-     * This returns the minimized loss values for a given vector.
-     * It is assumed that the x, y pairs are interleaved:
-     * x values at even indices, y values at odd indices
-     * @param vector the vector of numbers to get the weights for
-     * @return a double array with w_0 at index 0 and w_1 at index 1
-     */
-    public static double[] weightsFor(double[] vector) {
-
-        /* split coordinate system */
-        List<double[]> coords = coordSplit(vector);
-        /* x vals */
-        double[] x = coords.get(0);
-        /* y vals */
-        double[] y = coords.get(1);
-
-
-        double meanX = sum(x) / x.length;
-        double meanY = sum(y) / y.length;
-
-        double sumOfMeanDifferences = sumOfMeanDifferences(x, y);
-        double xDifferenceOfMean = sumOfMeanDifferencesOnePoint(x);
-
-        double w_1 = sumOfMeanDifferences / xDifferenceOfMean;
-
-        double w_0 = meanY - (w_1) * meanX;
-
-        // only the first two slots are meaningful; the remainder stays zero
-        double[] ret = new double[vector.length];
-        ret[0] = w_0;
-        ret[1] = w_1;
-
-        return ret;
-    }//end weightsFor
-
-    public static double errorFor(double actual, double prediction) {
-        return actual - prediction;
-    }
-
-    /**
-     * Used for calculating top part of simple regression for
-     * beta 1
-     * @param vector the x coordinates
-     * @param vector2 the y coordinates
-     * @return the sum of mean differences for the input vectors
-     */
-    public static double sumOfMeanDifferences(double[] vector, double[] vector2) {
-        double mean = sum(vector) / vector.length;
-        double mean2 = sum(vector2) / vector2.length;
-        double ret = 0;
-        for (int i = 0; i < vector.length; i++) {
-            double vec1Diff = vector[i] - mean;
-            double vec2Diff = vector2[i] - mean2;
-            ret += vec1Diff * vec2Diff;
-        }
-        return ret;
-    }//end sumOfMeanDifferences
-
-    /**
-     * Used for calculating top part of simple regression for
-     * beta 1
-     * @param vector the x coordinates
-     * @return the sum of squared mean differences for the input vector
-     */
-    public static double sumOfMeanDifferencesOnePoint(double[] vector) {
-        double mean = sum(vector) / vector.length;
-        double ret = 0;
-        for (int i = 0; i < vector.length; i++) {
-            double vec1Diff = Math.pow(vector[i] - mean, 2);
-            ret += vec1Diff;
-        }
-        return ret;
-    }//end sumOfMeanDifferences
-
-    public static double variance(double[] vector) {
-        return sumOfMeanDifferencesOnePoint(vector) / vector.length;
-    }
-
-    /**
-     * This returns the product of all numbers in the given array.
-     * @param nums the numbers to multiply over
-     * @return the product of all numbers in the array, or 0
-     * if the array is null or empty
-     */
-    public static double times(double[] nums) {
-        if (nums == null || nums.length == 0)
-            return 0;
-        double ret = 1;
-        for (int i = 0; i < nums.length; i++)
-            ret *= nums[i];
-        return ret;
-    }//end times
-
-
-    /**
-     * This returns the sum of products for the given
-     * arrays.
-     * @param nums the arrays to take the products over
-     * @return the sum of products for the given arrays
-     */
-    public static double sumOfProducts(double[]... nums) {
-        if (nums == null || nums.length < 1)
-            return 0;
-        double sum = 0;
-
-        // iterate over the element positions within the arrays
-        for (int i = 0; i < nums[0].length; i++) {
-            /* The ith column for all of the rows */
-            double[] column = column(i, nums);
-            sum += times(column);
-        }
-        return sum;
-    }//end sumOfProducts
-
-
-    /**
-     * This returns the given column over n arrays
-     * @param column the column to get values for
-     * @param nums the arrays to extract values from
-     * @return a double array containing all of the numbers in that column
-     * for all of the arrays.
-     * @throws IllegalArgumentException if the index is < 0
-     */
-    private static double[] column(int column, double[]... nums) throws IllegalArgumentException {
-
-        double[] ret = new double[nums.length];
-
-        for (int i = 0; i < nums.length; i++) {
-            double[] curr = nums[i];
-            ret[i] = curr[column];
-        }
-        return ret;
-    }//end column
-
-    /**
-     * This returns the coordinate split in a list of coordinates
-     * such that the values for ret[0] are the x values
-     * and ret[1] are the y values
-     * @param vector the vector to split with x and y values
-     * @return a coordinate split for the given vector of values;
-     * if null is passed in, null is returned
-     */
-    public static List<double[]> coordSplit(double[] vector) {
-
-        if (vector == null)
-            return null;
-        List<double[]> ret = new ArrayList<>();
-        /* x coordinates */
-        double[] xVals = new double[vector.length / 2];
-        /* y coordinates */
-        double[] yVals = new double[vector.length / 2];
-        /* current points */
-        int xTracker = 0;
-        int yTracker = 0;
-        for (int i = 0; i < vector.length; i++) {
-            //even index, x coordinate
-            if (i % 2 == 0)
-                xVals[xTracker++] = vector[i];
-            //y coordinate
-            else
-                yVals[yTracker++] = vector[i];
-        }
-        ret.add(xVals);
-        ret.add(yVals);
-
-        return ret;
-    }//end coordSplit
-
-
-    /**
-     * This returns the coordinate split in a list of coordinates
-     * such that the values for ret[0] are the x values
-     * and ret[1] are the y values
-     * @param vector the vector to split with x and y values.
-     * Note that the list will be more stable due to the size operator.
-     * The array version will have extraneous values if not monitored
-     * properly.
-     * @return a coordinate split for the given vector of values;
-     * if null is passed in, null is returned
-     */
-    public static List<double[]> coordSplit(List<Double> vector) {
-
-        if (vector == null)
-            return null;
-        List<double[]> ret = new ArrayList<>();
-        /* x coordinates */
-        double[] xVals = new double[vector.size() / 2];
-        /* y coordinates */
-        double[] yVals = new double[vector.size() / 2];
-        /* current points */
-        int xTracker = 0;
-        int yTracker = 0;
-        for (int i = 0; i < vector.size(); i++) {
-            //even index, x coordinate
-            if (i % 2 == 0)
-                xVals[xTracker++] = vector.get(i);
-            //y coordinate
-            else
-                yVals[yTracker++] = vector.get(i);
-        }
-        ret.add(xVals);
-        ret.add(yVals);
-
-        return ret;
-    }//end coordSplit
-
-
-
-    /**
-     * This returns the x values of the given vector.
-     * These are assumed to be the even indexed values of the vector.
-     * @param vector the vector to get the values for
-     * @return the x values of the given vector
-     */
-    public static double[] xVals(double[] vector) {
-
-
-        if (vector == null)
-            return null;
-        double[] x = new double[vector.length / 2];
-        int count = 0;
-        for (int i = 0; i < vector.length; i++) {
-            // even indices hold the x values
-            if (i % 2 == 0)
-                x[count++] = vector[i];
-        }
-        return x;
-    }//end xVals
-
-    /**
-     * This returns the y values (the odd indexed values) of the given vector
-     * @param vector the vector to get the odd indexed values of
-     * @return the y values of the given vector
-     */
-    public static double[] yVals(double[] vector) {
-        double[] y = new double[vector.length / 2];
-        int count = 0;
-        for (int i = 0; i < vector.length; i++) {
-            // odd indices hold the y values
-            if (i % 2 != 0)
-                y[count++] = vector[i];
-        }
-        return y;
-    }//end yVals
-
-
-    /**
-     * This returns the sum of squares for the given vector.
- * - * @param vector the vector to obtain the sum of squares for - * @return the sum of squares for this vector - */ - public static double sumOfSquares(double[] vector) { - double ret = 0; - for (double d : vector) - ret += Math.pow(d, 2); - return ret; - } - - /** - * This returns the determination coefficient of two vectors given a length - * @param y1 the first vector - * @param y2 the second vector - * @param n the length of both vectors - * @return the determination coefficient or r^2 - */ - public static double determinationCoefficient(double[] y1, double[] y2, int n) { - return Math.pow(correlation(y1, y2), 2); - } - - - - /** - * Returns the logarithm of a for base 2. - * - * @param a a double - * @return the logarithm for base 2 - */ - public static double log2(double a) { - if (a == 0) - return 0.0; - return Math.log(a) / log2; - } - - /** - * This returns the slope of the given points. - * @param x1 the first x to use - * @param x2 the end x to use - * @param y1 the begin y to use - * @param y2 the end y to use - * @return the slope of the given points - */ - public double slope(double x1, double x2, double y1, double y2) { - return (y2 - y1) / (x2 - x1); - }//end slope - - /** - * This returns the root mean squared error of two data sets - * @param real the real values - * @param predicted the predicted values - * @return the root means squared error for two data sets - */ - public static double rootMeansSquaredError(double[] real, double[] predicted) { - double ret = 0.0; - for (int i = 0; i < real.length; i++) { - ret += Math.pow((real[i] - predicted[i]), 2); - } - return Math.sqrt(ret / real.length); - }//end rootMeansSquaredError - - /** - * This returns the entropy (information gain, or uncertainty of a random variable). - * @param vector the vector of values to getFromOrigin the entropy for - * @return the entropy of the given vector - */ - public static double entropy(double[] vector) { - if (vector == null || vector.length < 1) - return 0; - else { - double ret = 0; - for (double d : vector) - ret += d * Math.log(d); - return ret; - - } - }//end entropy - - /** - * This returns the kronecker delta of two doubles. - * @param i the first number to compare - * @param j the second number to compare - * @return 1 if they are equal, 0 otherwise - */ - public static int kroneckerDelta(double i, double j) { - return (i == j) ? 1 : 0; - } - - /** - * This calculates the adjusted r^2 including degrees of freedom. - * Also known as calculating "strength" of a regression - * @param rSquared the r squared value to calculate - * @param numRegressors number of variables - * @param numDataPoints size of the data applyTransformToDestination - * @return an adjusted r^2 for degrees of freedom - */ - public static double adjustedrSquared(double rSquared, int numRegressors, int numDataPoints) { - double divide = (numDataPoints - 1.0) / (numDataPoints - numRegressors - 1.0); - double rSquaredDiff = 1 - rSquared; - return 1 - (rSquaredDiff * divide); - } - - - public static double[] normalizeToOne(double[] doubles) { - normalize(doubles, sum(doubles)); - return doubles; - } - - public static double min(double[] doubles) { - double ret = doubles[0]; - for (double d : doubles) - if (d < ret) - ret = d; - return ret; - } - - public static double max(double[] doubles) { - double ret = doubles[0]; - for (double d : doubles) - if (d > ret) - ret = d; - return ret; - } - - /** - * Normalizes the doubles in the array using the given value. 
- * - * @param doubles the array of double - * @param sum the value by which the doubles are to be normalized - * @exception IllegalArgumentException if sum is zero or NaN - */ - public static void normalize(double[] doubles, double sum) { - - if (Double.isNaN(sum)) { - throw new IllegalArgumentException("Can't normalize array. Sum is NaN."); - } - if (sum == 0) { - // Maybe this should just be a return. - throw new IllegalArgumentException("Can't normalize array. Sum is zero."); - } - for (int i = 0; i < doubles.length; i++) { - doubles[i] /= sum; - } - }//end normalize - - /** - * Converts an array containing the natural logarithms of - * probabilities stored in a vector back into probabilities. - * The probabilities are assumed to sum to one. - * - * @param a an array holding the natural logarithms of the probabilities - * @return the converted array - */ - public static double[] logs2probs(double[] a) { - - double max = a[maxIndex(a)]; - double sum = 0.0; - - double[] result = new double[a.length]; - for (int i = 0; i < a.length; i++) { - result[i] = Math.exp(a[i] - max); - sum += result[i]; - } - - normalize(result, sum); - - return result; - }//end logs2probs - - /** - * This returns the entropy for a given vector of probabilities. - * @param probabilities the probabilities to getFromOrigin the entropy for - * @return the entropy of the given probabilities. - */ - public static double information(double[] probabilities) { - double total = 0.0; - for (double d : probabilities) { - total += (-1.0 * log2(d) * d); - } - return total; - }//end information - - /** - * - * - * Returns index of maximum element in a given - * array of doubles. First maximum is returned. - * - * @param doubles the array of doubles - * @return the index of the maximum element - */ - public static /*@pure@*/ int maxIndex(double[] doubles) { - - double maximum = 0; - int maxIndex = 0; - - for (int i = 0; i < doubles.length; i++) { - if ((i == 0) || (doubles[i] > maximum)) { - maxIndex = i; - maximum = doubles[i]; - } - } - - return maxIndex; - }//end maxIndex - - /** - * This will return the factorial of the given number n. - * @param n the number to getFromOrigin the factorial for - * @return the factorial for this number - */ - public static double factorial(double n) { - if (n == 1 || n == 0) - return 1; - for (double i = n; i > 0; i--, n *= (i > 0 ? i : 1)) { - } - return n; - }//end factorial - - - - /** The small deviation allowed in double comparisons. */ - public static double SMALL = 1e-6; - - /** - * Returns the log-odds for a given probability. - * - * @param prob the probability - * - * @return the log-odds after the probability has been mapped to - * [Utils.SMALL, 1-Utils.SMALL] - */ - public static /*@pure@*/ double probToLogOdds(double prob) { - - if (gr(prob, 1) || (sm(prob, 0))) { - throw new IllegalArgumentException("probToLogOdds: probability must " + "be in [0,1] " + prob); - } - double p = SMALL + (1.0 - 2 * SMALL) * prob; - return Math.log(p / (1 - p)); - } - - /** - * Rounds a double to the next nearest integer value. The JDK version - * of it doesn't work properly. - * - * @param value the double value - * @return the resulting integer value - */ - public static /*@pure@*/ int round(double value) { - - return value > 0 ? (int) (value + 0.5) : -(int) (Math.abs(value) + 0.5); - }//end round - - /** - * This returns the permutation of n choose r. 
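logs2probs above uses the classic max-subtraction trick: subtracting the largest log before exponentiating keeps Math.exp from underflowing, and the results are then renormalised to sum to one. A quick numerical check with illustrative inputs:

import org.deeplearning4j.clustering.util.MathUtils;

public class Logs2ProbsExample {
    public static void main(String[] args) {
        // natural logs of the unnormalised weights {1, 2, 7}
        double[] logs = {Math.log(1), Math.log(2), Math.log(7)};
        double[] probs = MathUtils.logs2probs(logs);
        // probs is {0.1, 0.2, 0.7} and sums to one
        System.out.println(java.util.Arrays.toString(probs));
    }
}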
- * @param n the n to choose - * @param r the number of elements to choose - * @return the permutation of these numbers - */ - public static double permutation(double n, double r) { - double nFac = MathUtils.factorial(n); - double nMinusRFac = MathUtils.factorial((n - r)); - return nFac / nMinusRFac; - }//end permutation - - - /** - * This returns the combination of n choose r - * @param n the number of elements overall - * @param r the number of elements to choose - * @return the amount of possible combinations for this applyTransformToDestination of elements - */ - public static double combination(double n, double r) { - double nFac = MathUtils.factorial(n); - double rFac = MathUtils.factorial(r); - double nMinusRFac = MathUtils.factorial((n - r)); - - return nFac / (rFac * nMinusRFac); - }//end combination - - - /** - * sqrt(a^2 + b^2) without under/overflow. - */ - public static double hypotenuse(double a, double b) { - double r; - if (Math.abs(a) > Math.abs(b)) { - r = b / a; - r = Math.abs(a) * Math.sqrt(1 + r * r); - } else if (b != 0) { - r = a / b; - r = Math.abs(b) * Math.sqrt(1 + r * r); - } else { - r = 0.0; - } - return r; - }//end hypotenuse - - /** - * Rounds a double to the next nearest integer value in a probabilistic - * fashion (e.g. 0.8 has a 20% chance of being rounded down to 0 and a - * 80% chance of being rounded up to 1). In the limit, the average of - * the rounded numbers generated by this procedure should converge to - * the original double. - * - * @param value the double value - * @param rand the random number generator - * @return the resulting integer value - */ - public static int probRound(double value, Random rand) { - - if (value >= 0) { - double lower = Math.floor(value); - double prob = value - lower; - if (rand.nextDouble() < prob) { - return (int) lower + 1; - } else { - return (int) lower; - } - } else { - double lower = Math.floor(Math.abs(value)); - double prob = Math.abs(value) - lower; - if (rand.nextDouble() < prob) { - return -((int) lower + 1); - } else { - return -(int) lower; - } - } - }//end probRound - - /** - * Rounds a double to the given number of decimal places. - * - * @param value the double value - * @param afterDecimalPoint the number of digits after the decimal point - * @return the double rounded to the given precision - */ - public static /*@pure@*/ double roundDouble(double value, int afterDecimalPoint) { - - double mask = Math.pow(10.0, (double) afterDecimalPoint); - - return (double) (Math.round(value * mask)) / mask; - }//end roundDouble - - - - /** - * Rounds a double to the given number of decimal places. - * - * @param value the double value - * @param afterDecimalPoint the number of digits after the decimal point - * @return the double rounded to the given precision - */ - public static /*@pure@*/ float roundFloat(float value, int afterDecimalPoint) { - - float mask = (float) Math.pow(10, (float) afterDecimalPoint); - - return (float) (Math.round(value * mask)) / mask; - }//end roundDouble - - /** - * This will return the bernoulli trial for the given event. - * A bernoulli trial is a mechanism for detecting the probability - * of a given event occurring k times in n independent trials - * @param n the number of trials - * @param k the number of times the target event occurs - * @param successProb the probability of the event happening - * @return the probability of the given event occurring k times. 
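The bernoullis method that follows computes the binomial probability mass function P(X = k) = C(n, k) * p^k * (1 - p)^(n - k) from the combination helper above. A quick sanity check for a fair coin (the numbers are illustrative):

import org.deeplearning4j.clustering.util.MathUtils;

public class BernoulliExample {
    public static void main(String[] args) {
        // probability of exactly 2 heads in 4 fair tosses: C(4,2) * 0.5^4 = 0.375
        System.out.println(MathUtils.bernoullis(4, 2, 0.5));
    }
}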
- */ - public static double bernoullis(double n, double k, double successProb) { - - double combo = MathUtils.combination(n, k); - double q = 1 - successProb; - return combo * Math.pow(successProb, k) * Math.pow(q, n - k); - }//end bernoullis - - /** - * Tests if a is smaller than b. - * - * @param a a double - * @param b a double - */ - public static /*@pure@*/ boolean sm(double a, double b) { - - return (b - a > SMALL); - } - - /** - * Tests if a is greater than b. - * - * @param a a double - * @param b a double - */ - public static /*@pure@*/ boolean gr(double a, double b) { - - return (a - b > SMALL); - } - - /** - * This will take a given string and separator and convert it to an equivalent - * double array. - * @param data the data to separate - * @param separator the separator to use - * @return the new double array based on the given data - */ - public static double[] fromString(String data, String separator) { - String[] split = data.split(separator); - double[] ret = new double[split.length]; - for (int i = 0; i < split.length; i++) { - ret[i] = Double.parseDouble(split[i]); - } - return ret; - }//end fromString - - /** - * Computes the mean for an array of doubles. - * - * @param vector the array - * @return the mean - */ - public static /*@pure@*/ double mean(double[] vector) { - - double sum = 0; - - if (vector.length == 0) { - return 0; - } - for (int i = 0; i < vector.length; i++) { - sum += vector[i]; - } - return sum / (double) vector.length; - }//end mean - - /** - * This will return the cholesky decomposition of - * the given matrix - * @param m the matrix to convert - * @return the cholesky decomposition of the given - * matrix. - * See: - * http://en.wikipedia.org/wiki/Cholesky_decomposition - * @throws NonSquareMatrixException - */ - public CholeskyDecomposition choleskyFromMatrix(RealMatrix m) throws Exception { - return new CholeskyDecomposition(m); - }//end choleskyFromMatrix - - - - /** - * This will convert the given binary string to a decimal based - * integer - * @param binary the binary string to convert - * @return an equivalent base 10 number - */ - public static int toDecimal(String binary) { - long num = Long.parseLong(binary); - long rem; - /* Use the remainder method to ensure validity */ - while (num > 0) { - rem = num % 10; - num = num / 10; - if (rem != 0 && rem != 1) { - System.out.println("This is not a binary number."); - System.out.println("Please try once again."); - return -1; - } - } - return Integer.parseInt(binary, 2); - }//end toDecimal - - - /** - * This will translate a vector in to an equivalent integer - * @param vector the vector to translate - * @return a z value such that the value is the interleaved lsd to msd for each - * double in the vector - */ - public static int distanceFinderZValue(double[] vector) { - StringBuilder binaryBuffer = new StringBuilder(); - List binaryReps = new ArrayList<>(vector.length); - for (int i = 0; i < vector.length; i++) { - double d = vector[i]; - int j = (int) d; - String binary = Integer.toBinaryString(j); - binaryReps.add(binary); - } - //append from left to right, the least to the most significant bit - //till all strings are empty - while (!binaryReps.isEmpty()) { - for (int j = 0; j < binaryReps.size(); j++) { - String curr = binaryReps.get(j); - if (!curr.isEmpty()) { - char first = curr.charAt(0); - binaryBuffer.append(first); - curr = curr.substring(1); - binaryReps.set(j, curr); - } else - binaryReps.remove(j); - } - } - return Integer.parseInt(binaryBuffer.toString(), 2); - - }//end 
distanceFinderZValue
-
-    /**
-     * This returns the squared Euclidean distance between two vectors:
-     * sum(i=1,n) (q_i - p_i)^2
-     * @param p the first vector
-     * @param q the second vector
-     * @return the squared distance between the two vectors
-     */
-    public static double euclideanDistance(double[] p, double[] q) {
-
-        double ret = 0;
-        for (int i = 0; i < p.length; i++) {
-            double diff = (q[i] - p[i]);
-            double sq = Math.pow(diff, 2);
-            ret += sq;
-        }
-        return ret;
-
-    }//end euclideanDistance
-
-    /**
-     * This returns the squared Euclidean distance between two vectors:
-     * sum(i=1,n) (q_i - p_i)^2
-     * @param p the first vector
-     * @param q the second vector
-     * @return the squared distance between the two vectors
-     */
-    public static double euclideanDistance(float[] p, float[] q) {
-
-        double ret = 0;
-        for (int i = 0; i < p.length; i++) {
-            double diff = (q[i] - p[i]);
-            double sq = Math.pow(diff, 2);
-            ret += sq;
-        }
-        return ret;
-
-    }//end euclideanDistance
-
-    /**
-     * This will generate l uniformly distributed
-     * random numbers in [0, 1)
-     * @param l the number of numbers to generate
-     * @return l uniformly generated numbers
-     */
-    public static double[] generateUniform(int l) {
-        double[] ret = new double[l];
-        Random rgen = new Random();
-        for (int i = 0; i < l; i++) {
-            ret[i] = rgen.nextDouble();
-        }
-        return ret;
-    }//end generateUniform
-
-
-    /**
-     * This will calculate the Manhattan distance between two sets of points.
-     * The Manhattan distance is equivalent to:
-     * sum(i=1,n) |p_i - q_i|
-     * @param p the first point vector
-     * @param q the second point vector
-     * @return the Manhattan distance between the two vectors
-     */
-    public static double manhattanDistance(double[] p, double[] q) {
-
-        double ret = 0;
-        for (int i = 0; i < p.length; i++) {
-            double difference = p[i] - q[i];
-            ret += Math.abs(difference);
-        }
-        return ret;
-    }//end manhattanDistance
-
-
-
-    public static double[] sampleDoublesInInterval(double[][] doubles, int l) {
-        double[] sample = new double[l];
-        for (int i = 0; i < l; i++) {
-            // sample a row first, then a valid column within that row
-            int rand1 = randomNumberBetween(0, doubles.length - 1);
-            int rand2 = randomNumberBetween(0, doubles[rand1].length - 1);
-            sample[i] = doubles[rand1][rand2];
-        }
-
-        return sample;
-    }
-
-    /**
-     * Generates a random integer between the specified numbers
-     * @param begin the begin of the interval
-     * @param end the end of the interval
-     * @return an int between begin and end
-     */
-    public static int randomNumberBetween(double begin, double end) {
-        if (begin > end)
-            throw new IllegalArgumentException("Begin must not be greater than end");
-        return (int) begin + (int) (Math.random() * ((end - begin) + 1));
-    }
-
-    /**
-     * Generates a random integer between the specified numbers
-     * @param begin the begin of the interval
-     * @param end the end of the interval
-     * @return an int between begin and end
-     */
-    public static int randomNumberBetween(double begin, double end, RandomGenerator rng) {
-        if (begin > end)
-            throw new IllegalArgumentException("Begin must not be greater than end");
-        return (int) begin + (int) (rng.nextDouble() * ((end - begin) + 1));
-    }
-
-    /**
-     * Generates a random integer between the specified numbers
-     * @param begin the begin of the interval
-     * @param end the end of the interval
-     * @return an int between begin and end
-     */
-    public static int randomNumberBetween(double begin, double end, org.nd4j.linalg.api.rng.Random rng) {
-        if (begin > end)
-            throw new IllegalArgumentException("Begin must not be greater than end");
-        return (int) begin + (int) (rng.nextDouble() * ((end - begin) + 1));
-    }
-
-    /**
-     *
-     * @param begin
-     *
@param end the end of the interval - * @return a float between begin and end - */ - public static float randomFloatBetween(float begin, float end) { - float rand = (float) Math.random(); - return begin + (rand * ((end - begin))); - } - - /** - * Generates a random double between the specified numbers - * @param begin the begin of the interval - * @param end the end of the interval - * @return a double between begin and end - */ - public static double randomDoubleBetween(double begin, double end) { - return begin + (Math.random() * ((end - begin))); - } - - /** Shuffles the array in place with a seeded RNG (Fisher-Yates). */ - public static void shuffleArray(int[] array, long rngSeed) { - shuffleArray(array, new Random(rngSeed)); - } - - public static void shuffleArray(int[] array, Random rng) { - //https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm - for (int i = array.length - 1; i > 0; i--) { - //swap a random element of the unshuffled prefix into position i - int j = rng.nextInt(i + 1); - int temp = array[j]; - array[j] = array[i]; - array[i] = temp; - } - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/util/MultiThreadUtils.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/util/MultiThreadUtils.java deleted file mode 100644 index c147c474e..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/util/MultiThreadUtils.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
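The arithmetic helpers above are easy to sanity-check by hand. A minimal usage sketch follows, assuming the deleted org.deeplearning4j.clustering.util.MathUtils class is still on the classpath; the class name MathUtilsSketch and the sample values are illustrative only:

import java.util.Arrays;

public class MathUtilsSketch {
    public static void main(String[] args) {
        double[] p = {0.0, 0.0};
        double[] q = {3.0, 4.0};
        // squared Euclidean distance: (3-0)^2 + (4-0)^2 = 25
        System.out.println(MathUtils.euclideanDistance(p, q)); // 25.0
        // Manhattan distance: |3-0| + |4-0| = 7
        System.out.println(MathUtils.manhattanDistance(p, q)); // 7.0
        // P(X = 1) for Binomial(n = 2, p = 0.5): C(2,1) * 0.5 * 0.5 = 0.5
        System.out.println(MathUtils.bernoullis(2, 1, 0.5)); // 0.5
        int[] a = {1, 2, 3, 4, 5};
        MathUtils.shuffleArray(a, 42L); // seeded Fisher-Yates shuffle, deterministic per seed
        System.out.println(Arrays.toString(a));
    }
}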
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.util; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.List; -import java.util.concurrent.*; - -public class MultiThreadUtils { - - private static Logger log = LoggerFactory.getLogger(MultiThreadUtils.class); - - private static ExecutorService instance; - - private MultiThreadUtils() {} - - /** Creates a fixed-size pool of daemon threads, sized to the number of available processors. */ - public static synchronized ExecutorService newExecutorService() { - int nThreads = Runtime.getRuntime().availableProcessors(); - return new ThreadPoolExecutor(nThreads, nThreads, 60L, TimeUnit.SECONDS, new LinkedTransferQueue<Runnable>(), - new ThreadFactory() { - @Override - public Thread newThread(Runnable r) { - Thread t = Executors.defaultThreadFactory().newThread(r); - t.setDaemon(true); - return t; - } - }); - } - - /** Submits all tasks to the given executor and blocks until every task has completed. */ - public static void parallelTasks(final List<Runnable> tasks, ExecutorService executorService) { - int tasksCount = tasks.size(); - final CountDownLatch latch = new CountDownLatch(tasksCount); - for (int i = 0; i < tasksCount; i++) { - final int taskIdx = i; - executorService.execute(new Runnable() { - public void run() { - try { - tasks.get(taskIdx).run(); - } catch (Throwable e) { - log.info("Unchecked exception thrown by task", e); - } finally { - latch.countDown(); - } - } - }); - } - - try { - latch.await(); - } catch (Exception e) { - throw new RuntimeException(e); - } - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/util/SetUtils.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/util/SetUtils.java deleted file mode 100755 index eecf576d0..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/util/SetUtils.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
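A short usage sketch of the two helpers above; it assumes the class is used as-is, and note the pool is daemon-threaded, so it will not by itself keep the JVM alive:

import org.deeplearning4j.clustering.util.MultiThreadUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;

public class ParallelTasksSketch {
    public static void main(String[] args) {
        ExecutorService pool = MultiThreadUtils.newExecutorService();
        List<Runnable> tasks = new ArrayList<>();
        for (int i = 0; i < 8; i++) {
            final int id = i;
            tasks.add(() -> System.out.println("task " + id + " on " + Thread.currentThread().getName()));
        }
        // blocks until all 8 tasks have run (or thrown); exceptions are logged, not rethrown
        MultiThreadUtils.parallelTasks(tasks, pool);
        pool.shutdown();
    }
}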
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.util; - -import java.util.Collection; -import java.util.HashSet; -import java.util.Set; - -public class SetUtils { - private SetUtils() {} - - // Set specific operations - - public static <T> Set<T> intersection(Collection<T> parentCollection, Collection<T> removeFromCollection) { - Set<T> results = new HashSet<>(parentCollection); - results.retainAll(removeFromCollection); - return results; - } - - public static <T> boolean intersectionP(Set<? extends T> s1, Set<? extends T> s2) { - for (T elt : s1) { - if (s2.contains(elt)) - return true; - } - return false; - } - - public static <T> Set<T> union(Set<? extends T> s1, Set<? extends T> s2) { - Set<T> s3 = new HashSet<>(s1); - s3.addAll(s2); - return s3; - } - - /** Returns s1 \ s2 (the elements of s1 that are not in s2). */ - - public static <T> Set<T> difference(Collection<? extends T> s1, Collection<? extends T> s2) { - Set<T> s3 = new HashSet<>(s1); - s3.removeAll(s2); - return s3; - } -} - - diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTree.java deleted file mode 100644 index e4f699289..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTree.java +++ /dev/null @@ -1,633 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
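For reference, a minimal sketch of the set operations above with concrete values (class name SetUtilsSketch is illustrative only):

import org.deeplearning4j.clustering.util.SetUtils;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

public class SetUtilsSketch {
    public static void main(String[] args) {
        Set<Integer> s1 = new HashSet<>(Arrays.asList(1, 2, 3));
        Set<Integer> s2 = new HashSet<>(Arrays.asList(3, 4));
        System.out.println(SetUtils.intersection(s1, s2));  // [3]
        System.out.println(SetUtils.intersectionP(s1, s2)); // true, the sets share element 3
        System.out.println(SetUtils.union(s1, s2));         // [1, 2, 3, 4]
        System.out.println(SetUtils.difference(s1, s2));    // [1, 2]
    }
}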
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.vptree; - -import lombok.*; -import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.clustering.sptree.DataPoint; -import org.deeplearning4j.clustering.sptree.HeapObject; -import org.deeplearning4j.clustering.util.MathUtils; -import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; -import org.nd4j.linalg.api.memory.enums.*; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.reduce3.*; -import org.nd4j.linalg.exception.ND4JIllegalStateException; -import org.nd4j.linalg.factory.Nd4j; - -import java.io.Serializable; -import java.util.*; -import java.util.concurrent.*; -import java.util.concurrent.atomic.AtomicInteger; - -@Slf4j -@Builder -@AllArgsConstructor -public class VPTree implements Serializable { - private static final long serialVersionUID = 1L; - - public static final String EUCLIDEAN = "euclidean"; - private double tau; - @Getter - @Setter - private INDArray items; - private List<DataPoint> itemsList; - private Node root; - private String similarityFunction; - @Getter - private boolean invert = false; - private transient ExecutorService executorService; - @Getter - private int workers = 1; - private AtomicInteger size = new AtomicInteger(0); - - private transient ThreadLocal<INDArray> scalars = new ThreadLocal<>(); - - private WorkspaceConfiguration workspaceConfiguration; - - protected VPTree() { - // constructor for serialization only - scalars = new ThreadLocal<>(); - } - - /** - * - * @param points the points (one per row) to build the tree from - * @param invert whether to invert the distance metric - */ - public VPTree(INDArray points, boolean invert) { - this(points, "euclidean", 1, invert); - } - - /** - * - * @param points the points (one per row) to build the tree from - * @param invert whether to invert the distance metric - * @param workers number of parallel workers for tree building (increases memory requirements!) - */ - public VPTree(INDArray points, boolean invert, int workers) { - this(points, "euclidean", workers, invert); - } - - /** - * - * @param items the items to use - * @param similarityFunction the similarity function to use - * @param invert whether to invert the distance (similarity functions have different min/max objectives) - */ - public VPTree(INDArray items, String similarityFunction, boolean invert) { - this.similarityFunction = similarityFunction; - this.invert = invert; - this.items = items; - root = buildFromPoints(items); - workers = 1; - } - - /** - * - * @param items the items to use - * @param similarityFunction the similarity function to use - * @param workers number of parallel workers for tree building (increases memory requirements!)
- * @param invert whether to invert the metric (different optimization objective) - */ - public VPTree(List items, String similarityFunction, int workers, boolean invert) { - this.workers = workers; - - val list = new INDArray[items.size()]; - - // build list of INDArrays first - for (int i = 0; i < items.size(); i++) - list[i] = items.get(i).getPoint(); - //this.items.putRow(i, items.get(i).getPoint()); - - // just stack them out with concat :) - this.items = Nd4j.pile(list); - - this.invert = invert; - this.similarityFunction = similarityFunction; - root = buildFromPoints(this.items); - } - - - - /** - * - * @param items - * @param similarityFunction - */ - public VPTree(INDArray items, String similarityFunction) { - this(items, similarityFunction, 1, false); - } - - /** - * - * @param items - * @param similarityFunction - * @param workers number of parallel workers for tree building (increases memory requirements!) - * @param invert - */ - public VPTree(INDArray items, String similarityFunction, int workers, boolean invert) { - this.similarityFunction = similarityFunction; - this.invert = invert; - this.items = items; - - this.workers = workers; - root = buildFromPoints(items); - } - - - /** - * - * @param items - * @param similarityFunction - */ - public VPTree(List items, String similarityFunction) { - this(items, similarityFunction, 1, false); - } - - - /** - * - * @param items - */ - public VPTree(INDArray items) { - this(items, EUCLIDEAN); - } - - - /** - * - * @param items - */ - public VPTree(List items) { - this(items, EUCLIDEAN); - } - - /** - * Create an ndarray - * from the datapoints - * @param data - * @return - */ - public static INDArray buildFromData(List data) { - INDArray ret = Nd4j.create(data.size(), data.get(0).getD()); - for (int i = 0; i < ret.slices(); i++) - ret.putSlice(i, data.get(i).getPoint()); - return ret; - } - - - - /** - * - * @param basePoint - * @param distancesArr - */ - public void calcDistancesRelativeTo(INDArray items, INDArray basePoint, INDArray distancesArr) { - switch (similarityFunction) { - case "euclidean": - Nd4j.getExecutioner().exec(new EuclideanDistance(items, basePoint, distancesArr, true,-1)); - break; - case "cosinedistance": - Nd4j.getExecutioner().exec(new CosineDistance(items, basePoint, distancesArr, true, -1)); - break; - case "cosinesimilarity": - Nd4j.getExecutioner().exec(new CosineSimilarity(items, basePoint, distancesArr, true, -1)); - break; - case "manhattan": - Nd4j.getExecutioner().exec(new ManhattanDistance(items, basePoint, distancesArr, true, -1)); - break; - case "dot": - Nd4j.getExecutioner().exec(new Dot(items, basePoint, distancesArr, -1)); - break; - case "jaccard": - Nd4j.getExecutioner().exec(new JaccardDistance(items, basePoint, distancesArr, true, -1)); - break; - case "hamming": - Nd4j.getExecutioner().exec(new HammingDistance(items, basePoint, distancesArr, true, -1)); - break; - default: - Nd4j.getExecutioner().exec(new EuclideanDistance(items, basePoint, distancesArr, true, -1)); - break; - - } - - if (invert) - distancesArr.negi(); - - } - - public void calcDistancesRelativeTo(INDArray basePoint, INDArray distancesArr) { - calcDistancesRelativeTo(items, basePoint, distancesArr); - } - - - /** - * Euclidean distance - * @return the distance between the two points - */ - public double distance(INDArray arr1, INDArray arr2) { - if (scalars == null) - scalars = new ThreadLocal<>(); - - if (scalars.get() == null) - scalars.set(Nd4j.scalar(arr1.dataType(), 0.0)); - - switch (similarityFunction) { - case 
"jaccard": - double ret7 = Nd4j.getExecutioner() - .execAndReturn(new JaccardDistance(arr1, arr2, scalars.get())) - .getFinalResult().doubleValue(); - return invert ? -ret7 : ret7; - case "hamming": - double ret8 = Nd4j.getExecutioner() - .execAndReturn(new HammingDistance(arr1, arr2, scalars.get())) - .getFinalResult().doubleValue(); - return invert ? -ret8 : ret8; - case "euclidean": - double ret = Nd4j.getExecutioner() - .execAndReturn(new EuclideanDistance(arr1, arr2, scalars.get())) - .getFinalResult().doubleValue(); - return invert ? -ret : ret; - case "cosinesimilarity": - double ret2 = Nd4j.getExecutioner() - .execAndReturn(new CosineSimilarity(arr1, arr2, scalars.get())) - .getFinalResult().doubleValue(); - return invert ? -ret2 : ret2; - case "cosinedistance": - double ret6 = Nd4j.getExecutioner() - .execAndReturn(new CosineDistance(arr1, arr2, scalars.get())) - .getFinalResult().doubleValue(); - return invert ? -ret6 : ret6; - case "manhattan": - double ret3 = Nd4j.getExecutioner() - .execAndReturn(new ManhattanDistance(arr1, arr2, scalars.get())) - .getFinalResult().doubleValue(); - return invert ? -ret3 : ret3; - case "dot": - double dotRet = Nd4j.getBlasWrapper().dot(arr1, arr2); - return invert ? -dotRet : dotRet; - default: - double ret4 = Nd4j.getExecutioner() - .execAndReturn(new EuclideanDistance(arr1, arr2, scalars.get())) - .getFinalResult().doubleValue(); - return invert ? -ret4 : ret4; - - } - } - - protected class NodeBuilder implements Callable { - protected List list; - protected List indices; - - public NodeBuilder(List list, List indices) { - this.list = list; - this.indices = indices; - } - - @Override - public Node call() throws Exception { - return buildFromPoints(list, indices); - } - } - - private Node buildFromPoints(List points, List indices) { - Node ret = new Node(0, 0); - - - // nothing to sort here - if (points.size() == 1) { - ret.point = points.get(0); - ret.index = indices.get(0); - return ret; - } - - // opening workspace, and creating it if that's the first call - /* MemoryWorkspace workspace = - Nd4j.getWorkspaceManager().getAndActivateWorkspace(workspaceConfiguration, "VPTREE_WORSKPACE");*/ - - INDArray items = Nd4j.vstack(points); - int randomPoint = MathUtils.randomNumberBetween(0, items.rows() - 1, Nd4j.getRandom()); - INDArray basePoint = points.get(randomPoint);//items.getRow(randomPoint); - ret.point = basePoint; - ret.index = indices.get(randomPoint); - INDArray distancesArr = Nd4j.create(items.rows(), 1); - - calcDistancesRelativeTo(items, basePoint, distancesArr); - - double medianDistance = distancesArr.medianNumber().doubleValue(); - - ret.threshold = (float) medianDistance; - - List leftPoints = new ArrayList<>(); - List leftIndices = new ArrayList<>(); - List rightPoints = new ArrayList<>(); - List rightIndices = new ArrayList<>(); - - for (int i = 0; i < distancesArr.length(); i++) { - if (i == randomPoint) - continue; - - if (distancesArr.getDouble(i) < medianDistance) { - leftPoints.add(points.get(i)); - leftIndices.add(indices.get(i)); - } else { - rightPoints.add(points.get(i)); - rightIndices.add(indices.get(i)); - } - } - - // closing workspace - //workspace.notifyScopeLeft(); - //log.info("Thread: {}; Workspace size: {} MB; ConstantCache: {}; ShapeCache: {}; TADCache: {}", Thread.currentThread().getId(), (int) (workspace.getCurrentSize() / 1024 / 1024 ), Nd4j.getConstantHandler().getCachedBytes(), Nd4j.getShapeInfoProvider().getCachedBytes(), Nd4j.getExecutioner().getTADManager().getCachedBytes()); - - if (workers > 1) { - 
if (!leftPoints.isEmpty()) - ret.futureLeft = executorService.submit(new NodeBuilder(leftPoints, leftIndices)); // = buildFromPoints(leftPoints); - - if (!rightPoints.isEmpty()) - ret.futureRight = executorService.submit(new NodeBuilder(rightPoints, rightIndices)); - } else { - if (!leftPoints.isEmpty()) - ret.left = buildFromPoints(leftPoints, leftIndices); - - if (!rightPoints.isEmpty()) - ret.right = buildFromPoints(rightPoints, rightIndices); - } - - return ret; - } - - private Node buildFromPoints(INDArray items) { - if (executorService == null && items == this.items && workers > 1) { - final val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); - - executorService = Executors.newFixedThreadPool(workers, new ThreadFactory() { - @Override - public Thread newThread(final Runnable r) { - Thread t = new Thread(new Runnable() { - - @Override - public void run() { - Nd4j.getAffinityManager().unsafeSetDevice(deviceId); - r.run(); - } - }); - - t.setDaemon(true); - t.setName("VPTree thread"); - - return t; - } - }); - } - - - final Node ret = new Node(0, 0); - size.incrementAndGet(); - - /*workspaceConfiguration = WorkspaceConfiguration.builder().cyclesBeforeInitialization(1) - .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.FIRST_LOOP) - .policyMirroring(MirroringPolicy.FULL).policyReset(ResetPolicy.BLOCK_LEFT) - .policySpill(SpillPolicy.REALLOCATE).build(); - - // opening workspace - MemoryWorkspace workspace = - Nd4j.getWorkspaceManager().getAndActivateWorkspace(workspaceConfiguration, "VPTREE_WORSKPACE");*/ - - int randomPoint = MathUtils.randomNumberBetween(0, items.rows() - 1, Nd4j.getRandom()); - INDArray basePoint = items.getRow(randomPoint, true); - INDArray distancesArr = Nd4j.create(items.rows(), 1); - ret.point = basePoint; - ret.index = randomPoint; - - calcDistancesRelativeTo(items, basePoint, distancesArr); - - double medianDistance = distancesArr.medianNumber().doubleValue(); - - ret.threshold = (float) medianDistance; - - List leftPoints = new ArrayList<>(); - List leftIndices = new ArrayList<>(); - List rightPoints = new ArrayList<>(); - List rightIndices = new ArrayList<>(); - - for (int i = 0; i < distancesArr.length(); i++) { - if (i == randomPoint) - continue; - - if (distancesArr.getDouble(i) < medianDistance) { - leftPoints.add(items.getRow(i, true)); - leftIndices.add(i); - } else { - rightPoints.add(items.getRow(i, true)); - rightIndices.add(i); - } - } - - // closing workspace - //workspace.notifyScopeLeft(); - //workspace.destroyWorkspace(true); - - if (!leftPoints.isEmpty()) - ret.left = buildFromPoints(leftPoints, leftIndices); - - if (!rightPoints.isEmpty()) - ret.right = buildFromPoints(rightPoints, rightIndices); - - // destroy once again - //workspace.destroyWorkspace(true); - - if (ret.left != null) - ret.left.fetchFutures(); - - if (ret.right != null) - ret.right.fetchFutures(); - - if (executorService != null) - executorService.shutdown(); - - return ret; - } - - public void search(@NonNull INDArray target, int k, List results, List distances) { - search(target, k, results, distances, true); - } - - public void search(@NonNull INDArray target, int k, List results, List distances, - boolean filterEqual) { - search(target, k, results, distances, filterEqual, false); - } - /** - * - * @param target - * @param k - * @param results - * @param distances - */ - public void search(@NonNull INDArray target, int k, List results, List distances, - boolean filterEqual, boolean dropEdge) { - if (items != null) - if 
(!target.isVectorOrScalar() || target.columns() != items.columns() || target.rows() > 1) - throw new ND4JIllegalStateException("Target for search should have shape of [" + 1 + ", " - + items.columns() + "] but got " + Arrays.toString(target.shape()) + " instead"); - - k = Math.min(k, items.rows()); - results.clear(); - distances.clear(); - - PriorityQueue pq = new PriorityQueue<>(items.rows(), new HeapObjectComparator()); - - search(root, target, k + (filterEqual ? 2 : 1), pq, Double.MAX_VALUE); - - while (!pq.isEmpty()) { - HeapObject ho = pq.peek(); - results.add(new DataPoint(ho.getIndex(), ho.getPoint())); - distances.add(ho.getDistance()); - pq.poll(); - } - - Collections.reverse(results); - Collections.reverse(distances); - - if (dropEdge || results.size() > k) { - if (filterEqual && distances.get(0) == 0.0) { - results.remove(0); - distances.remove(0); - } - - while (results.size() > k) { - results.remove(results.size() - 1); - distances.remove(distances.size() - 1); - } - } - } - - /** - * - * @param node - * @param target - * @param k - * @param pq - */ - public void search(Node node, INDArray target, int k, PriorityQueue pq, double cTau) { - - if (node == null) - return; - - double tau = cTau; - - INDArray get = node.getPoint(); //items.getRow(node.getIndex()); - double distance = distance(get, target); - if (distance < tau) { - if (pq.size() == k) - pq.poll(); - - pq.add(new HeapObject(node.getIndex(), node.getPoint(), distance)); - if (pq.size() == k) - tau = pq.peek().getDistance(); - } - - Node left = node.getLeft(); - Node right = node.getRight(); - - if (left == null && right == null) - return; - - if (distance < node.getThreshold()) { - if (distance - tau < node.getThreshold()) { // if there can still be neighbors inside the ball, recursively search left child first - search(left, target, k, pq, tau); - } - - if (distance + tau >= node.getThreshold()) { // if there can still be neighbors outside the ball, recursively search right child - search(right, target, k, pq, tau); - } - - } else { - if (distance + tau >= node.getThreshold()) { // if there can still be neighbors outside the ball, recursively search right child first - search(right, target, k, pq, tau); - } - - if (distance - tau < node.getThreshold()) { // if there can still be neighbors inside the ball, recursively search left child - search(left, target, k, pq, tau); - } - } - - } - - - protected class HeapObjectComparator implements Comparator { - - @Override - public int compare(HeapObject o1, HeapObject o2) { - return Double.compare(o2.getDistance(), o1.getDistance()); - } - } - - @Data - public static class Node implements Serializable { - private static final long serialVersionUID = 2L; - - private int index; - private float threshold; - private Node left, right; - private INDArray point; - protected transient Future futureLeft; - protected transient Future futureRight; - - public Node(int index, float threshold) { - this.index = index; - this.threshold = threshold; - } - - - public void fetchFutures() { - try { - if (futureLeft != null) { - /*while (!futureLeft.isDone()) - Thread.sleep(100);*/ - - - left = futureLeft.get(); - } - - if (futureRight != null) { - /*while (!futureRight.isDone()) - Thread.sleep(100);*/ - - right = futureRight.get(); - } - - - if (left != null) - left.fetchFutures(); - - if (right != null) - right.fetchFutures(); - } catch (Exception e) { - throw new RuntimeException(e); - } - - - } - } - -} diff --git 
a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTreeFillSearch.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTreeFillSearch.java deleted file mode 100644 index 2cf87d69b..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTreeFillSearch.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.vptree; - -import lombok.Getter; -import org.deeplearning4j.clustering.sptree.DataPoint; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.ArrayList; -import java.util.List; - -public class VPTreeFillSearch { - private VPTree vpTree; - private int k; - @Getter - private List results; - @Getter - private List distances; - private INDArray target; - - public VPTreeFillSearch(VPTree vpTree, int k, INDArray target) { - this.vpTree = vpTree; - this.k = k; - this.target = target; - } - - public void search() { - results = new ArrayList<>(); - distances = new ArrayList<>(); - //initial search - //vpTree.search(target,k,results,distances); - - //fill till there is k results - //by going down the list - // if(results.size() < k) { - INDArray distancesArr = Nd4j.create(vpTree.getItems().rows(), 1); - vpTree.calcDistancesRelativeTo(target, distancesArr); - INDArray[] sortWithIndices = Nd4j.sortWithIndices(distancesArr, 0, !vpTree.isInvert()); - results.clear(); - distances.clear(); - if (vpTree.getItems().isVector()) { - for (int i = 0; i < k; i++) { - int idx = sortWithIndices[0].getInt(i); - results.add(new DataPoint(idx, Nd4j.scalar(vpTree.getItems().getDouble(idx)))); - distances.add(sortWithIndices[1].getDouble(idx)); - } - } else { - for (int i = 0; i < k; i++) { - int idx = sortWithIndices[0].getInt(i); - results.add(new DataPoint(idx, vpTree.getItems().getRow(idx))); - //distances.add(sortWithIndices[1].getDouble(idx)); - distances.add(sortWithIndices[1].getDouble(i)); - } - } - - - } - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/package-info.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/package-info.java deleted file mode 100644 index 49d19a719..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/package-info.java +++ 
/dev/null @@ -1,21 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.vptree; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/cluster/ClusterSetTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/cluster/ClusterSetTest.java deleted file mode 100644 index 5a83fa85b..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/cluster/ClusterSetTest.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
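Putting the two vantage-point classes together: a usage sketch, assuming the deleted VPTree and VPTreeFillSearch classes are available as defined above. The constructor, search signatures, and the "euclidean" constant come straight from those files; the data shapes are illustrative:

import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.clustering.vptree.VPTreeFillSearch;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;

public class VPTreeSketch {
    public static void main(String[] args) {
        INDArray points = Nd4j.rand(100, 8);           // 100 points in 8 dimensions
        VPTree tree = new VPTree(points, "euclidean"); // median-split vantage-point tree

        INDArray query = points.getRow(0, true);
        List<DataPoint> results = new ArrayList<>();
        List<Double> distances = new ArrayList<>();
        // 5 nearest neighbours; the default overload filters out an exact match with the query
        tree.search(query, 5, results, distances);

        // brute-force fallback that always returns exactly k results, sorted by distance
        VPTreeFillSearch fill = new VPTreeFillSearch(tree, 5, query);
        fill.search();
        System.out.println(fill.getResults().size()); // 5
    }
}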
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.cluster; - -import org.junit.Assert; -import org.junit.Test; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.ArrayList; -import java.util.List; - -public class ClusterSetTest { - @Test - public void testGetMostPopulatedClusters() { - ClusterSet clusterSet = new ClusterSet(false); - List clusters = new ArrayList<>(); - for (int i = 0; i < 5; i++) { - Cluster cluster = new Cluster(); - cluster.setPoints(Point.toPoints(Nd4j.randn(i + 1, 5))); - clusters.add(cluster); - } - clusterSet.setClusters(clusters); - List mostPopulatedClusters = clusterSet.getMostPopulatedClusters(5); - for (int i = 0; i < 5; i++) { - Assert.assertEquals(5 - i, mostPopulatedClusters.get(i).getPoints().size()); - } - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java deleted file mode 100644 index e436d62f5..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java +++ /dev/null @@ -1,422 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
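The tests below exercise the KDTree API; as a preview, a compact build-and-query sketch mirroring the calls they make. The Pair type parameters are inferred from the assertions in the tests and may differ in the real signatures:

import org.deeplearning4j.clustering.kdtree.KDTree;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.List;

public class KDTreeSketch {
    public static void main(String[] args) {
        KDTree tree = new KDTree(2); // 2-dimensional tree
        tree.insert(Nd4j.createFromArray(new float[]{7, 2}));
        tree.insert(Nd4j.createFromArray(new float[]{5, 4}));
        tree.insert(Nd4j.createFromArray(new float[]{2, 3}));

        // nearest neighbour: a (distance, point) pair
        Pair<Double, INDArray> nn = tree.nn(Nd4j.createFromArray(new float[]{6, 3}));
        System.out.println(nn.getKey() + " -> " + nn.getValue());

        // all neighbours within a radius, closest first
        List<Pair<Float, INDArray>> knn = tree.knn(Nd4j.createFromArray(new float[]{6, 3}), 5.0f);
        System.out.println(knn.size());
    }
}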
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.kdtree; - -import lombok.val; -import org.deeplearning4j.BaseDL4JTest; -import org.joda.time.Duration; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Ignore; -import org.junit.Test; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.Pair; -import org.nd4j.shade.guava.base.Stopwatch; -import org.nd4j.shade.guava.primitives.Doubles; -import org.nd4j.shade.guava.primitives.Floats; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Random; - -import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static java.util.concurrent.TimeUnit.SECONDS; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -public class KDTreeTest extends BaseDL4JTest { - - @Override - public long getTimeoutMilliseconds() { - return 120000L; - } - - private KDTree kdTree; - - @BeforeClass - public static void beforeClass(){ - Nd4j.setDataType(DataType.FLOAT); - } - - @Before - public void setUp() { - kdTree = new KDTree(2); - float[] data = new float[]{7,2}; - kdTree.insert(Nd4j.createFromArray(data)); - data = new float[]{5,4}; - kdTree.insert(Nd4j.createFromArray(data)); - data = new float[]{2,3}; - kdTree.insert(Nd4j.createFromArray(data)); - data = new float[]{4,7}; - kdTree.insert(Nd4j.createFromArray(data)); - data = new float[]{9,6}; - kdTree.insert(Nd4j.createFromArray(data)); - data = new float[]{8,1}; - kdTree.insert(Nd4j.createFromArray(data)); - } - - @Test - public void testTree() { - KDTree tree = new KDTree(2); - INDArray half = Nd4j.create(new double[] {0.5, 0.5}, new long[]{1,2}).castTo(DataType.FLOAT); - INDArray one = Nd4j.create(new double[] {1, 1}, new long[]{1,2}).castTo(DataType.FLOAT); - tree.insert(half); - tree.insert(one); - Pair pair = tree.nn(Nd4j.create(new double[] {0.5, 0.5}, new long[]{1,2}).castTo(DataType.FLOAT)); - assertEquals(half, pair.getValue()); - } - - @Test - public void testInsert() { - int elements = 10; - List digits = Arrays.asList(1.0, 0.0, 2.0, 3.0); - - KDTree kdTree = new KDTree(digits.size()); - List> lists = new ArrayList<>(); - for (int i = 0; i < elements; i++) { - List thisList = new ArrayList<>(digits.size()); - for (int k = 0; k < digits.size(); k++) { - thisList.add(digits.get(k) + i); - } - lists.add(thisList); - } - - for (int i = 0; i < elements; i++) { - double[] features = Doubles.toArray(lists.get(i)); - INDArray ind = Nd4j.create(features, new long[]{1, features.length}, DataType.FLOAT); - kdTree.insert(ind); - assertEquals(i + 1, kdTree.size()); - } - } - - @Test - public void testDelete() { - int elements = 10; - List digits = Arrays.asList(1.0, 0.0, 2.0, 3.0); - - KDTree kdTree = new KDTree(digits.size()); - List> lists = new ArrayList<>(); - for (int i = 0; i < elements; i++) { - List thisList = new ArrayList<>(digits.size()); - for (int k = 0; k < digits.size(); k++) { - thisList.add(digits.get(k) + i); - } - lists.add(thisList); - } - - INDArray toDelete = Nd4j.empty(DataType.DOUBLE), - leafToDelete = Nd4j.empty(DataType.DOUBLE); - for (int i = 0; i < elements; i++) { - double[] features = Doubles.toArray(lists.get(i)); - INDArray ind = Nd4j.create(features, new long[]{1, features.length}, DataType.FLOAT); - if (i == 1) - toDelete = ind; - if 
(i == elements - 1) { - leafToDelete = ind; - } - kdTree.insert(ind); - assertEquals(i + 1, kdTree.size()); - } - - kdTree.delete(toDelete); - assertEquals(9, kdTree.size()); - kdTree.delete(leafToDelete); - assertEquals(8, kdTree.size()); - } - - @Test - public void testNN() { - int n = 10; - - // make a KD-tree of dimension {#n} - KDTree kdTree = new KDTree(n); - for (int i = -1; i < n; i++) { - // Insert a unit vector along each dimension - List vec = new ArrayList<>(n); - // i = -1 ensures the origin is in the Tree - for (int k = 0; k < n; k++) { - vec.add((k == i) ? 1.0 : 0.0); - } - INDArray indVec = Nd4j.create(Doubles.toArray(vec), new long[]{1, vec.size()}, DataType.FLOAT); - kdTree.insert(indVec); - } - Random rand = new Random(); - - // random point in the Hypercube - List pt = new ArrayList(n); - for (int k = 0; k < n; k++) { - pt.add(rand.nextDouble()); - } - Pair result = kdTree.nn(Nd4j.create(Doubles.toArray(pt), new long[]{1, pt.size()}, DataType.FLOAT)); - - // Always true for points in the unitary hypercube - assertTrue(result.getKey() < Double.MAX_VALUE); - - } - - @Test - public void testKNN() { - int dimensions = 512; - int vectorsNo = isIntegrationTests() ? 50000 : 1000; - // make a KD-tree of dimension {#dimensions} - Stopwatch stopwatch = Stopwatch.createStarted(); - KDTree kdTree = new KDTree(dimensions); - for (int i = -1; i < vectorsNo; i++) { - // Insert a unit vector along each dimension - INDArray indVec = Nd4j.rand(DataType.FLOAT, 1,dimensions); - kdTree.insert(indVec); - } - stopwatch.stop(); - System.out.println("Time elapsed for " + kdTree.size() + " nodes construction is "+ stopwatch.elapsed(SECONDS)); - - Random rand = new Random(); - // random point in the Hypercube - List pt = new ArrayList(dimensions); - for (int k = 0; k < dimensions; k++) { - pt.add(rand.nextFloat() * 10.0); - } - stopwatch.reset(); - stopwatch.start(); - List> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Floats.toArray(pt))), 20.0f); - stopwatch.stop(); - System.out.println("Time elapsed for Search is "+ stopwatch.elapsed(MILLISECONDS)); - } - - @Test - public void testKNN_Simple() { - int n = 2; - KDTree kdTree = new KDTree(n); - - float[] data = new float[]{3,3}; - kdTree.insert(Nd4j.createFromArray(data)); - data = new float[]{1,1}; - kdTree.insert(Nd4j.createFromArray(data)); - data = new float[]{2,2}; - kdTree.insert(Nd4j.createFromArray(data)); - - data = new float[]{0,0}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 4.5f); - - assertEquals(1.0, result.get(0).getSecond().getDouble(0), 1e-5); - assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5); - - assertEquals(2.0, result.get(1).getSecond().getDouble(0), 1e-5); - assertEquals(2.0, result.get(1).getSecond().getDouble(1), 1e-5); - - assertEquals(3.0, result.get(2).getSecond().getDouble(0), 1e-5); - assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5); - } - - @Test - public void testKNN_1() { - - assertEquals(6, kdTree.size()); - - float[] data = new float[]{8,1}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f); - assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5); - assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5); - assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5); - assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5); - assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5); - assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5); - assertEquals(9.0, result.get(3).getSecond().getFloat(0), 1e-5); - 
assertEquals(6.0, result.get(3).getSecond().getFloat(1), 1e-5); - assertEquals(2.0, result.get(4).getSecond().getFloat(0), 1e-5); - assertEquals(3.0, result.get(4).getSecond().getFloat(1), 1e-5); - assertEquals(4.0, result.get(5).getSecond().getFloat(0), 1e-5); - assertEquals(7.0, result.get(5).getSecond().getFloat(1), 1e-5); - } - - @Test - public void testKNN_2() { - float[] data = new float[]{8, 1}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f); - assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5); - assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5); - assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5); - assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5); - assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5); - assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5); - } - - @Test - public void testKNN_3() { - - float[] data = new float[]{2, 3}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f); - assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5); - assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5); - assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5); - assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5); - assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5); - assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5); - assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5); - assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5); - assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5); - assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5); - assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5); - assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5); - } - - - @Test - public void testKNN_4() { - float[] data = new float[]{2, 3}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f); - assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5); - assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5); - assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5); - assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5); - assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5); - assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5); - } - - @Test - public void testKNN_5() { - float[] data = new float[]{2, 3}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f); - assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5); - assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5); - assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5); - assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5); - assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5); - assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5); - assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5); - assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5); - assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5); - assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5); - assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5); - assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5); - } - - @Test - public void test_KNN_6() { - float[] data = new float[]{4, 6}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f); - assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); - assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); - 
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); - assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); - assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5); - assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5); - assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5); - assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5); - assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5); - assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5); - assertEquals(8.0, result.get(5).getSecond().getDouble(0), 1e-5); - assertEquals(1.0, result.get(5).getSecond().getDouble(1), 1e-5); - } - - @Test - public void test_KNN_7() { - float[] data = new float[]{4, 6}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f); - assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); - assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); - assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); - assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); - assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5); - assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5); - assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5); - assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5); - assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5); - assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5); - } - - @Test - public void test_KNN_8() { - float[] data = new float[]{4, 6}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f); - assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); - assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); - assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); - assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); - assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5); - assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5); - assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5); - assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5); - assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5); - assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5); - assertEquals(8.0, result.get(5).getSecond().getDouble(0), 1e-5); - assertEquals(1.0, result.get(5).getSecond().getDouble(1), 1e-5); - } - - @Test - public void testNoDuplicates() { - int N = 100; - KDTree bigTree = new KDTree(2); - - List points = new ArrayList<>(); - for (int i = 0; i < N; ++i) { - double[] data = new double[]{i, i}; - points.add(Nd4j.createFromArray(data)); - } - - for (int i = 0; i < N; ++i) { - bigTree.insert(points.get(i)); - } - - assertEquals(N, bigTree.size()); - - INDArray node = Nd4j.empty(DataType.DOUBLE); - for (int i = 0; i < N; ++i) { - node = bigTree.delete(node.isEmpty() ? 
points.get(i) : node); - } - - assertEquals(0, bigTree.size()); - } - - @Ignore - @Test - public void performanceTest() { - int n = 2; - int num = 100000; - // make a KD-tree of dimension {#n} - long start = System.currentTimeMillis(); - KDTree kdTree = new KDTree(n); - INDArray inputArray = Nd4j.randn(DataType.DOUBLE, num, n); - for (int i = 0 ; i < num; ++i) { - kdTree.insert(inputArray.getRow(i)); - } - - long end = System.currentTimeMillis(); - Duration duration = new Duration(start, end); - System.out.println("Elapsed time for tree construction " + duration.getStandardSeconds() + " " + duration.getMillis()); - - List<Float> pt = new ArrayList<>(n); - for (int k = 0; k < n; k++) { - pt.add((float)(num / 2)); - } - start = System.currentTimeMillis(); - List<Pair<Float, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0f); - end = System.currentTimeMillis(); - duration = new Duration(start, end); - System.out.println("Elapsed time for tree search " + duration.getStandardSeconds() + " " + duration.getMillis()); - for (val pair : list) { - System.out.println(pair.getFirst() + " " + pair.getSecond()); - } - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java deleted file mode 100644 index e3a2467ec..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java +++ /dev/null @@ -1,289 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
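The k-means tests that follow all share the same call pattern; a minimal sketch of it, using only calls that appear verbatim in the tests below (the point count and dimensionality are illustrative):

import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.Point;
import org.deeplearning4j.clustering.cluster.PointClassification;
import org.deeplearning4j.clustering.kmeans.KMeansClustering;
import org.nd4j.linalg.factory.Nd4j;
import java.util.List;

public class KMeansSketch {
    public static void main(String[] args) {
        // 5 clusters, at most 5 iterations, Euclidean distance, k-means++ initialization
        KMeansClustering kmeans = KMeansClustering.setup(5, 5, Distance.EUCLIDEAN, true);
        List<Point> points = Point.toPoints(Nd4j.randn(50, 5)); // 50 random 5-d points
        ClusterSet clusters = kmeans.applyTo(points);
        PointClassification pc = clusters.classifyPoint(points.get(0));
        System.out.println(pc.getCluster().getCenter().getArray()); // centroid of the assigned cluster
    }
}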
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.kmeans; - -import lombok.val; -import org.apache.commons.lang3.time.StopWatch; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.clustering.algorithm.Distance; -import org.deeplearning4j.clustering.cluster.*; -import org.junit.Ignore; -import org.junit.Test; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.List; - -import static org.junit.Assert.*; - -public class KMeansTest extends BaseDL4JTest { - - private boolean[] useKMeansPlusPlus = {true, false}; - - @Override - public long getTimeoutMilliseconds() { - return 60000L; - } - - @Test - public void testKMeans() { - Nd4j.getRandom().setSeed(7); - for (boolean mode : useKMeansPlusPlus) { - KMeansClustering kMeansClustering = KMeansClustering.setup(5, 5, Distance.EUCLIDEAN, mode); - List points = Point.toPoints(Nd4j.randn(5, 5)); - ClusterSet clusterSet = kMeansClustering.applyTo(points); - PointClassification pointClassification = clusterSet.classifyPoint(points.get(0)); - System.out.println(pointClassification); - } - } - - @Test - public void testKmeansCosine() { - - Nd4j.getRandom().setSeed(7); - int numClusters = 5; - for (boolean mode : useKMeansPlusPlus) { - KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.COSINE_DISTANCE, mode); - List points = Point.toPoints(Nd4j.rand(5, 300)); - ClusterSet clusterSet = kMeansClustering.applyTo(points); - PointClassification pointClassification = clusterSet.classifyPoint(points.get(0)); - - - KMeansClustering kMeansClusteringEuclidean = KMeansClustering.setup(numClusters, 1000, Distance.EUCLIDEAN, mode); - ClusterSet clusterSetEuclidean = kMeansClusteringEuclidean.applyTo(points); - PointClassification pointClassificationEuclidean = clusterSetEuclidean.classifyPoint(points.get(0)); - System.out.println("Cosine " + pointClassification); - System.out.println("Euclidean " + pointClassificationEuclidean); - - assertEquals(pointClassification.getCluster().getPoints().get(0), - pointClassificationEuclidean.getCluster().getPoints().get(0)); - } - } - - @Ignore - @Test - public void testPerformanceAllIterations() { - Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); - Nd4j.getRandom().setSeed(7); - int numClusters = 20; - for (boolean mode : useKMeansPlusPlus) { - StopWatch watch = new StopWatch(); - watch.start(); - KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.COSINE_DISTANCE, mode); - List points = Point.toPoints(Nd4j.linspace(0, 5000 * 300, 5000 * 300).reshape(5000, 300)); - - ClusterSet clusterSet = kMeansClustering.applyTo(points); - watch.stop(); - System.out.println("Elapsed for clustering : " + watch); - - watch.reset(); - watch.start(); - for (Point p : points) { - PointClassification pointClassification = clusterSet.classifyPoint(p); - } - watch.stop(); - System.out.println("Elapsed for search: " + watch); - } - } - - @Test - @Ignore - public void testPerformanceWithConvergence() { - Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); - Nd4j.getRandom().setSeed(7); - int numClusters = 20; - for (boolean mode : useKMeansPlusPlus) { - StopWatch watch = new StopWatch(); - watch.start(); - KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, Distance.COSINE_DISTANCE, false, mode); - - List points 
= Point.toPoints(Nd4j.linspace(0, 10000 * 300, 10000 * 300).reshape(10000, 300)); - - ClusterSet clusterSet = kMeansClustering.applyTo(points); - watch.stop(); - System.out.println("Elapsed for clustering : " + watch); - - watch.reset(); - watch.start(); - for (Point p : points) { - PointClassification pointClassification = clusterSet.classifyPoint(p); - } - watch.stop(); - System.out.println("Elapsed for search: " + watch); - - watch.reset(); - watch.start(); - kMeansClustering = KMeansClustering.setup(numClusters, 0.05, Distance.COSINE_DISTANCE, false, mode); - - points = Point.toPoints(Nd4j.linspace(0, 10000 * 300, 10000 * 300).reshape(10000, 300)); - - clusterSet = kMeansClustering.applyTo(points); - watch.stop(); - System.out.println("Elapsed for clustering : " + watch); - - watch.reset(); - watch.start(); - for (Point p : points) { - PointClassification pointClassification = clusterSet.classifyPoint(p); - } - watch.stop(); - System.out.println("Elapsed for search: " + watch); - } - } - - @Test - public void testCorrectness() { - - /*for (int c = 0; c < 10; ++c)*/ { - Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); - Nd4j.getRandom().setSeed(7); - int numClusters = 3; - for (boolean mode : useKMeansPlusPlus) { - KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.EUCLIDEAN, mode); - double[] data = new double[]{ - 15, 16, - 16, 18.5, - 17, 20.2, - 16.4, 17.12, - 17.23, 18.12, - 43, 43, - 44.43, 45.212, - 45.8, 54.23, - 46.313, 43.123, - 50.21, 46.3, - 99, 99.22, - 100.32, 98.123, - 100.32, 97.423, - 102, 93.23, - 102.23, 94.23 - }; - List points = Point.toPoints(Nd4j.createFromArray(data).reshape(15, 2)); - - ClusterSet clusterSet = kMeansClustering.applyTo(points); - - - INDArray row0 = Nd4j.createFromArray(new double[]{16.6575, 18.4850}); - INDArray row1 = Nd4j.createFromArray(new double[]{32.6050, 31.1500}); - INDArray row2 = Nd4j.createFromArray(new double[]{75.9348, 74.1990}); - - /*List clusters = clusterSet.getClusters(); - assertEquals(row0, clusters.get(0).getCenter().getArray()); - assertEquals(row1, clusters.get(1).getCenter().getArray()); - assertEquals(row2, clusters.get(2).getCenter().getArray());*/ - - PointClassification pointClassification = null; - for (Point p : points) { - pointClassification = clusterSet.classifyPoint(p); - System.out.println("Point: " + p.getArray() + " " + " assigned to cluster: " + pointClassification.getCluster().getCenter().getArray()); - List clusters = clusterSet.getClusters(); - for (int i = 0; i < clusters.size(); ++i) - System.out.println("Choice: " + clusters.get(i).getCenter().getArray()); - } - } - /*assertEquals(Nd4j.createFromArray(new double[]{75.9348, 74.1990}), - pointClassification.getCluster().getCenter().getArray());*/ - - /*clusters = clusterSet.getClusters(); - assertEquals(row0, clusters.get(0).getCenter().getArray()); - assertEquals(row1, clusters.get(1).getCenter().getArray()); - assertEquals(row2, clusters.get(2).getCenter().getArray());*/ - } - } - - @Test - public void testCentersHolder() { - int rows = 3, cols = 2; - CentersHolder ch = new CentersHolder(rows, cols); - - INDArray row0 = Nd4j.createFromArray(new double[]{16.4000, 17.1200}); - INDArray row1 = Nd4j.createFromArray(new double[]{45.8000, 54.2300}); - INDArray row2 = Nd4j.createFromArray(new double[]{95.9348, 94.1990}); - - ch.addCenter(row0); - ch.addCenter(row1); - ch.addCenter(row2); - - double[] data = new double[]{ - 15, 16, - 16, 18.5, - 17, 20.2, - 16.4, 17.12, - 17.23, 18.12, - 43, 43, - 44.43, 45.212, - 
45.8, 54.23, - 46.313, 43.123, - 50.21, 46.3, - 99, 99.22, - 100.32, 98.123, - 100.32, 97.423, - 102, 93.23, - 102.23, 94.23 - }; - - INDArray pointData = Nd4j.createFromArray(data); - List points = Point.toPoints(pointData.reshape(15,2)); - - for (int i = 0 ; i < points.size(); ++i) { - INDArray dist = ch.getMinDistances(points.get(i), Distance.EUCLIDEAN); - System.out.println("Point: " + points.get(i).getArray()); - System.out.println("Centers: " + ch.getCenters()); - System.out.println("Distance: " + dist); - System.out.println(); - } - } - - @Test - public void testInitClusters() { - Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); - Nd4j.getRandom().setSeed(7); - { - KMeansClustering kMeansClustering = KMeansClustering.setup(5, 1, Distance.EUCLIDEAN, true); - - double[][] dataArray = {{1000000.0, 2.8E7, 5.5E7, 8.2E7}, {2.8E7, 5.5E7, 8.2E7, 1.09E8}, {5.5E7, 8.2E7, 1.09E8, 1.36E8}, - {8.2E7, 1.09E8, 1.36E8, 1.63E8}, {1.09E8, 1.36E8, 1.63E8, 1.9E8}, {1.36E8, 1.63E8, 1.9E8, 2.17E8}, - {1.63E8, 1.9E8, 2.17E8, 2.44E8}, {1.9E8, 2.17E8, 2.44E8, 2.71E8}, {2.17E8, 2.44E8, 2.71E8, 2.98E8}, - {2.44E8, 2.71E8, 2.98E8, 3.25E8}, {2.71E8, 2.98E8, 3.25E8, 3.52E8}, {2.98E8, 3.25E8, 3.52E8, 3.79E8}, - {3.25E8, 3.52E8, 3.79E8, 4.06E8}, {3.52E8, 3.79E8, 4.06E8, 4.33E8}, {3.79E8, 4.06E8, 4.33E8, 4.6E8}, - {4.06E8, 4.33E8, 4.6E8, 4.87E8}, {4.33E8, 4.6E8, 4.87E8, 5.14E8}, {4.6E8, 4.87E8, 5.14E8, 5.41E8}, - {4.87E8, 5.14E8, 5.41E8, 5.68E8}, {5.14E8, 5.41E8, 5.68E8, 5.95E8}, {5.41E8, 5.68E8, 5.95E8, 6.22E8}, - {5.68E8, 5.95E8, 6.22E8, 6.49E8}, {5.95E8, 6.22E8, 6.49E8, 6.76E8}, {6.22E8, 6.49E8, 6.76E8, 7.03E8}, - {6.49E8, 6.76E8, 7.03E8, 7.3E8}, {6.76E8, 7.03E8, 7.3E8, 7.57E8}, {7.03E8, 7.3E8, 7.57E8, 7.84E8}}; - INDArray data = Nd4j.createFromArray(dataArray); - List points = Point.toPoints(data); - - ClusterSet clusterSet = kMeansClustering.applyTo(points); - - double[] centroid1 = {2.44e8, 2.71e8, 2.98e8, 3.25e8}; - double[] centroid2 = {1000000.0, 2.8E7, 5.5E7, 8.2E7}; - double[] centroid3 = {5.95E8, 6.22e8, 6.49e8, 6.76e8}; - double[] centroid4 = {3.79E8, 4.06E8, 4.33E8, 4.6E8}; - double[] centroid5 = {5.5E7, 8.2E7, 1.09E8, 1.36E8}; - - assertArrayEquals(centroid1, clusterSet.getClusters().get(0).getCenter().getArray().toDoubleVector(), 1e-4); - assertArrayEquals(centroid2, clusterSet.getClusters().get(1).getCenter().getArray().toDoubleVector(), 1e-4); - assertArrayEquals(centroid3, clusterSet.getClusters().get(2).getCenter().getArray().toDoubleVector(), 1e-4); - assertArrayEquals(centroid4, clusterSet.getClusters().get(3).getCenter().getArray().toDoubleVector(), 1e-4); - assertArrayEquals(centroid5, clusterSet.getClusters().get(4).getCenter().getArray().toDoubleVector(), 1e-4); - } - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSHTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSHTest.java deleted file mode 100644 index 105dd368a..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSHTest.java +++ /dev/null @@ -1,215 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * 
https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.lsh; - -import org.deeplearning4j.BaseDL4JTest; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.Random; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -public class RandomProjectionLSHTest extends BaseDL4JTest { - - int hashLength = 31; - int numTables = 2; - int intDimensions = 13; - - RandomProjectionLSH rpLSH; - INDArray e1; - INDArray inputs; - - @Before - public void setUp() { - Nd4j.getRandom().setSeed(12345); - Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); - rpLSH = new RandomProjectionLSH(hashLength, numTables, intDimensions, 0.1f); - inputs = Nd4j.rand(DataType.DOUBLE, 100, intDimensions); - e1 = Nd4j.ones(DataType.DOUBLE, 1, intDimensions); - } - - - @After - public void tearDown() { inputs = null; } - - @Test - public void testEntropyDims(){ - assertArrayEquals(new long[]{numTables, intDimensions}, rpLSH.entropy(e1).shape()); - } - - @Test - public void testHashDims(){ - assertArrayEquals(new long[]{1, hashLength}, rpLSH.hash(e1).shape()); - } - - @Test - public void testHashDimsMultiple(){ - INDArray data = Nd4j.ones(1, intDimensions); - assertArrayEquals(new long[]{1, hashLength}, rpLSH.hash(data).shape()); - - data = Nd4j.ones(100, intDimensions); - assertArrayEquals(new long[]{100, hashLength}, rpLSH.hash(data).shape()); - } - - @Test - public void testSigNums(){ - assertEquals(1.0f, rpLSH.hash(e1).aminNumber().floatValue(),1e-3f); - } - - - @Test - public void testIndexDims(){ - rpLSH.makeIndex(Nd4j.rand(100, intDimensions)); - assertArrayEquals(new long[]{100, hashLength}, rpLSH.index.shape()); - } - - - @Test - public void testGetRawBucketOfDims(){ - rpLSH.makeIndex(inputs); - assertArrayEquals(new long[]{100}, rpLSH.rawBucketOf(e1).shape()); - } - - @Test - public void testRawBucketOfReflexive(){ - rpLSH.makeIndex(inputs); - int idx = (new Random(12345)).nextInt(100); - INDArray row = inputs.getRow(idx, true); - assertEquals(1.0f, rpLSH.rawBucketOf(row).maxNumber().floatValue(), 1e-3f); - } - - @Test - public void testBucketDims(){ - rpLSH.makeIndex(inputs); - assertArrayEquals(new long[]{100}, rpLSH.bucket(e1).shape()); - } - - @Test - public void testBucketReflexive(){ - rpLSH.makeIndex(inputs); - int idx = (new Random(12345)).nextInt(100); - INDArray row = inputs.getRow(idx, true); - assertEquals(1.0f, rpLSH.bucket(row).maxNumber().floatValue(), 1e-3f); - } - - - @Test - public void testBucketDataReflexiveDimensions() { - rpLSH.makeIndex(inputs); - int idx = (new Random(12345)).nextInt(100); - INDArray row = 
inputs.getRow(idx, true); - INDArray bucketData = rpLSH.bucketData(row); - - assertEquals(intDimensions, bucketData.shape()[1]); - assertTrue(1 <= bucketData.shape()[0]); - } - - @Test - public void testBucketDataReflexive(){ - rpLSH.makeIndex(inputs); - int idx = (new Random(12345)).nextInt(100); - INDArray row = inputs.getRow(idx, true); - INDArray bucketData = rpLSH.bucketData(row); - - INDArray res = Nd4j.zeros(DataType.BOOL, bucketData.shape()); - Nd4j.getExecutioner().exec(new BroadcastEqualTo(bucketData, row, res, -1)); - res = res.castTo(DataType.FLOAT); - - assertEquals( - String.format("Expected one bucket content to be the query %s, but found %s", row, rpLSH.bucket(row)), - 1.0f, res.min(-1).maxNumber().floatValue(), 1e-3f); - } - - - @Test - public void testSearchReflexiveDimensions() { - rpLSH.makeIndex(inputs); - int idx = (new Random(12345)).nextInt(100); - INDArray row = inputs.getRow(idx, true); - INDArray searchResults = rpLSH.search(row, 10.0f); - - assertTrue( - String.format("Expected the search to return at least one result, the query %s but found %s yielding %d results", row, searchResults, searchResults.shape()[0]), - searchResults.shape()[0] >= 1); - } - - - @Test - public void testSearchReflexive() { - rpLSH.makeIndex(inputs); - int idx = (new Random(12345)).nextInt(100); - INDArray row = inputs.getRow(idx, true); - - INDArray searchResults = rpLSH.search(row, 10.0f); - - - INDArray res = Nd4j.zeros(DataType.BOOL, searchResults.shape()); - Nd4j.getExecutioner().exec(new BroadcastEqualTo(searchResults, row, res, -1)); - res = res.castTo(DataType.FLOAT); - - assertEquals( - String.format("Expected one search result to be the query %s, but found %s", row, searchResults), - 1.0f, res.min(-1).maxNumber().floatValue(), 1e-3f); - } - - - - @Test - public void testANNSearchReflexiveDimensions() { - rpLSH.makeIndex(inputs); - int idx = (new Random(12345)).nextInt(100); - INDArray row = inputs.getRow(idx, true); - INDArray searchResults = rpLSH.search(row, 100); - - assertTrue( - String.format("Expected the search to return at least one result, the query %s but found %s yielding %d results", row, searchResults, searchResults.shape()[0]), - searchResults.shape()[0] >= 1); - } - - - @Test - public void testANNSearchReflexive() { - rpLSH.makeIndex(inputs); - int idx = (new Random(12345)).nextInt(100); - INDArray row = inputs.getRow(idx).reshape(1, intDimensions); - - INDArray searchResults = rpLSH.search(row, 100); - - - INDArray res = Nd4j.zeros(DataType.BOOL, searchResults.shape()); - Nd4j.getExecutioner().exec(new BroadcastEqualTo(searchResults, row, res, -1)); - res = res.castTo(DataType.FLOAT); - - assertEquals( - String.format("Expected one search result to be the query %s, but found %s", row, searchResults), - 1.0f, res.min(-1).maxNumber().floatValue(), 1e-3f); - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/quadtree/QuadTreeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/quadtree/QuadTreeTest.java deleted file mode 100644 index 0cb77bd1d..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/quadtree/QuadTreeTest.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available 
under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.quadtree; - -import org.deeplearning4j.BaseDL4JTest; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -public class QuadTreeTest extends BaseDL4JTest { - - @Test - public void testQuadTree() { - INDArray n = Nd4j.ones(3, 2); - n.slice(1).addi(1); - n.slice(2).addi(2); - QuadTree quadTree = new QuadTree(n); - assertEquals(n.rows(), quadTree.getCumSize()); - assertTrue(quadTree.isCorrect()); - - - - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/randomprojection/RPTreeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/randomprojection/RPTreeTest.java deleted file mode 100644 index abb55a7fd..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/randomprojection/RPTreeTest.java +++ /dev/null @@ -1,101 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
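For reference, the QuadTree behaviour the deleted QuadTreeTest above covered can be reproduced with a short standalone sketch. This is a minimal sketch assuming only the org.deeplearning4j.clustering.quadtree.QuadTree API exactly as it appears in the removed test; the wrapper class and main method are illustrative, not part of the project.

```java
import org.deeplearning4j.clustering.quadtree.QuadTree;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class QuadTreeSketch {
    public static void main(String[] args) {
        // Three 2-D points: rows become (1,1), (2,2), (3,3) after the in-place adds.
        INDArray n = Nd4j.ones(3, 2);
        n.slice(1).addi(1);  // second row -> (2, 2)
        n.slice(2).addi(2);  // third row  -> (3, 3)

        QuadTree quadTree = new QuadTree(n);
        // The root's cumulative size counts every inserted point,
        // and isCorrect() checks each point lies inside its cell.
        System.out.println(quadTree.getCumSize());  // expected: 3 (== n.rows())
        System.out.println(quadTree.isCorrect());   // expected: true
    }
}
```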
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.randomprojection; - -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.junit.Before; -import org.junit.Test; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.dataset.api.DataSet; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler; -import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.List; - -import static org.junit.Assert.*; - -public class RPTreeTest extends BaseDL4JTest { - - @Before - public void setUp() { - Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); - } - - - @Test - public void testRPTree() throws Exception { - DataSetIterator mnist = new MnistDataSetIterator(150,150); - RPTree rpTree = new RPTree(784,50); - DataSet d = mnist.next(); - NormalizerStandardize normalizerStandardize = new NormalizerStandardize(); - normalizerStandardize.fit(d); - normalizerStandardize.transform(d.getFeatures()); - INDArray data = d.getFeatures(); - rpTree.buildTree(data); - assertEquals(4,rpTree.getLeaves().size()); - assertEquals(0,rpTree.getRoot().getDepth()); - - List candidates = rpTree.getCandidates(data.getRow(0)); - assertFalse(candidates.isEmpty()); - assertEquals(10,rpTree.query(data.slice(0),10).length()); - System.out.println(candidates.size()); - - rpTree.addNodeAtIndex(150,data.getRow(0)); - - } - - @Test - public void testFindSelf() throws Exception { - DataSetIterator mnist = new MnistDataSetIterator(100, 6000); - NormalizerMinMaxScaler minMaxNormalizer = new NormalizerMinMaxScaler(0, 1); - minMaxNormalizer.fit(mnist); - DataSet d = mnist.next(); - minMaxNormalizer.transform(d.getFeatures()); - RPForest rpForest = new RPForest(100, 100, "euclidean"); - rpForest.fit(d.getFeatures()); - for (int i = 0; i < 10; i++) { - INDArray indexes = rpForest.queryAll(d.getFeatures().slice(i), 10); - assertEquals(i,indexes.getInt(0)); - } - } - - @Test - public void testRpTreeMaxNodes() throws Exception { - DataSetIterator mnist = new MnistDataSetIterator(150,150); - RPForest rpTree = new RPForest(4,4,"euclidean"); - DataSet d = mnist.next(); - NormalizerStandardize normalizerStandardize = new NormalizerStandardize(); - normalizerStandardize.fit(d); - rpTree.fit(d.getFeatures()); - for(RPTree tree : rpTree.getTrees()) { - for(RPNode node : tree.getLeaves()) { - assertTrue(node.getIndices().size() <= rpTree.getMaxSize()); - } - } - - } - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/randomprojection/RPUtilsTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/randomprojection/RPUtilsTest.java deleted file mode 100644 index 18ca2ac9d..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/randomprojection/RPUtilsTest.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is 
available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.randomprojection; - -import org.deeplearning4j.BaseDL4JTest; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import static org.junit.Assert.assertEquals; - -public class RPUtilsTest extends BaseDL4JTest { - - @Test - public void testDistanceComputeBatch() { - INDArray x = Nd4j.linspace(1,4,4, Nd4j.dataType()).reshape(1, 4); - INDArray y = Nd4j.linspace(1,16,16, Nd4j.dataType()).reshape(4,4); - INDArray result = Nd4j.create(1, 4); - INDArray distances = RPUtils.computeDistanceMulti("euclidean",x,y,result); - INDArray scalarResult = Nd4j.scalar(1.0); - for(int i = 0; i < result.length(); i++) { - double dist = RPUtils.computeDistance("euclidean",x,y.slice(i),scalarResult); - assertEquals(dist,distances.getDouble(i),1e-3); - } - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java deleted file mode 100644 index 0ac39083b..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java +++ /dev/null @@ -1,108 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
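The RPUtilsTest removed above verified one invariant: the batched distance helper must agree with a row-by-row computation. A hedged sketch of that check, using only the RPUtils calls present in the deleted test (the wrapper class is an illustrative assumption):

```java
import org.deeplearning4j.clustering.randomprojection.RPUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RPUtilsSketch {
    public static void main(String[] args) {
        // One 4-d query vector and four 4-d candidate rows.
        INDArray x = Nd4j.linspace(1, 4, 4, Nd4j.dataType()).reshape(1, 4);
        INDArray y = Nd4j.linspace(1, 16, 16, Nd4j.dataType()).reshape(4, 4);
        INDArray result = Nd4j.create(1, 4);

        // All four distances in one batched call ...
        INDArray distances = RPUtils.computeDistanceMulti("euclidean", x, y, result);

        // ... should match computing each row's distance separately.
        INDArray scalar = Nd4j.scalar(1.0);
        for (int i = 0; i < 4; i++) {
            double d = RPUtils.computeDistance("euclidean", x, y.slice(i), scalar);
            System.out.printf("row %d: batch=%f single=%f%n", i, distances.getDouble(i), d);
        }
    }
}
```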
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.sptree; - -import org.apache.commons.lang3.time.StopWatch; -import org.deeplearning4j.BaseDL4JTest; -import org.junit.Before; -import org.junit.Test; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.buffer.util.DataTypeUtil; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.shade.guava.util.concurrent.AtomicDouble; - -import static org.junit.Assert.*; - -/** - * @author Adam Gibson - */ -public class SPTreeTest extends BaseDL4JTest { - - @Override - public long getTimeoutMilliseconds() { - return 120000L; - } - - @Before - public void setUp() { - DataTypeUtil.setDTypeForContext(DataType.DOUBLE); - } - - @Test - public void testStructure() { - INDArray data = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}}); - SpTree tree = new SpTree(data); - /*try (MemoryWorkspace ws = tree.workspace().notifyScopeEntered())*/ { - assertEquals(Nd4j.create(new double[]{2.5f, 3.5f, 4.5f}), tree.getCenterOfMass()); - assertEquals(2, tree.getCumSize()); - assertEquals(8, tree.getNumChildren()); - assertTrue(tree.isCorrect()); - } - } - - @Test - public void testComputeEdgeForces() { - Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); - double[] aData = new double[]{ - 0.2999816948164936, 0.26252049735806526, 0.2673853427498767, 0.8604464129156685, 0.4802652829902563, 0.10959096539488711, 0.7950242948008909, 0.5917848948003486, - 0.2738285999345498, 0.9519684328285567, 0.9690024759209738, 0.8585615547624705, 0.8087760944312002, 0.5337951589543348, 0.5960876109129123, 0.7187130179825856, - 0.4629777327445964, 0.08665909175584818, 0.7748005397731237, 0.48020186965468536, 0.24927351841378798, 0.32272599988270445, 0.306414968984427, 0.6980212149215657, - 0.7977183964212472, 0.7673513094629704, 0.1679681724796478, 0.3107359484804584, 0.021701726051792103, 0.13797462786662518, 0.8618953518813538, 0.841333838365635, - 0.5284957375170422, 0.9703367685039823, 0.677388096913733, 0.2624474979832243, 0.43740966353106536, 0.15685545957858893, 0.11072929134449871, 0.06007395961283357, - 0.4093918718557811, 0.9563909195720572, 0.5994144944480242, 0.8278927844215804, 0.38586830957105667, 0.6201844716257464, 0.7603829079070265, 0.07875691596842949, - 0.08651136699915507, 0.7445210640026082, 0.6547649514127559, 0.3384719042666908, 0.05816723105860,0.6248951423054205, 0.7431868493349041}; - INDArray data = Nd4j.createFromArray(aData).reshape(11,5); - INDArray rows = Nd4j.createFromArray(new int[]{ - 0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99}); - INDArray cols = Nd4j.createFromArray(new int[]{ - 4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1}); - INDArray vals = Nd4j.createFromArray(new double[] - { 0.6806, 0.1978, 0.1349, 0.0403, 0.0087, 0.0369, 0.0081, 0.0172, 0.0014, 0.0046, 0.0081, 0.3375, 0.2274, 0.0556, 0.0098, 0.0175, 0.0027, 0.0077, 0.0014, 0.0023, 0.0175, 0.6569, 0.1762, 0.0254, 0.0200, 0.0118, 0.0074, 0.0046, 0.0124, 0.0012, 0.1978, 0.0014, 0.0254, 0.7198, 0.0712, 0.0850, 0.0389, 0.0555, 0.0418, 0.0286, 0.6806, 0.3375, 0.0074, 0.0712, 0.2290, 0.0224, 0.0189, 0.0080, 0.0187, 0.0097, 0.0172, 0.0124, 0.0418, 
0.7799, 0.0521, 0.0395, 0.0097, 0.0030, 0.0023, 1.706e-5, 0.0087, 0.0027, 0.6569, 0.0850, 0.0080, 0.5562, 0.0173, 0.0015, 1.706e-5, 0.0369, 0.0077, 0.0286, 0.0187, 0.7799, 0.0711, 0.0200, 0.0084, 0.0012, 0.0403, 0.0556, 0.1762, 0.0389, 0.0224, 0.0030, 0.5562, 0.0084, 0.0060, 0.0028, 0.0014, 0.2274, 0.0200, 0.0555, 0.0189, 0.0521, 0.0015, 0.0711, 0.0028, 0.3911, 0.1349, 0.0098, 0.0118, 0.7198, 0.2290, 0.0395, 0.0173, 0.0200, 0.0060, 0.3911}); - SpTree tree = new SpTree(data); - INDArray posF = Nd4j.create(11, 5); - /*try (MemoryWorkspace ws = tree.workspace().notifyScopeEntered())*/ { - tree.computeEdgeForces(rows, cols, vals, 11, posF); - } - INDArray expected = Nd4j.createFromArray(new double[]{ -0.08045664291717945, -0.1010737980370276, 0.01793326162563703, 0.16108447776416351, -0.20679423033936287, -0.15788549368713395, 0.02546624825966788, 0.062309466206907055, -0.165806093080134, 0.15266225270841186, 0.17508365896345726, 0.09588570563583201, 0.34124767300538084, 0.14606666020839956, -0.06786563815470595, -0.09326646571247202, -0.19896040730569928, -0.3618837364446506, 0.13946315445146712, -0.04570186310149667, -0.2473462951783839, -0.41362278505023914, -0.1094083777758208, 0.10705807646770374, 0.24462088260113946, 0.21722270026621748, -0.21799892431326567, -0.08205544003080587, -0.11170161709042685, -0.2674768703060442, 0.03617747284043274, 0.16430316252598698, 0.04552845070022399, 0.2593696744801452, 0.1439989190892037, -0.059339471967457376, 0.05460893792863096, -0.0595168036583193, -0.2527693197519917, -0.15850951859835274, -0.2945536856938165, 0.15434659331638875, -0.022910846947667776, 0.23598009757792854, -0.11149279745674007, 0.09670616593772939, 0.11125703954547914, -0.08519984596392606, -0.12779827002328714, 0.23025192887225998, 0.13741473964038722, -0.06193553503816597, -0.08349781586292176, 0.1622156410642145, 0.155975447743472}).reshape(11,5); - for (int i = 0; i < 11; ++i) - assertArrayEquals(expected.getRow(i).toDoubleVector(), posF.getRow(i).toDoubleVector(), 1e-2); - - AtomicDouble sumQ = new AtomicDouble(0.0); - /*try (MemoryWorkspace ws = tree.workspace().notifyScopeEntered())*/ { - tree.computeNonEdgeForces(0, 0.5, Nd4j.zeros(5), sumQ); - } - assertEquals(8.65, sumQ.get(), 1e-2); - } - - @Test - //@Ignore - public void testLargeTree() { - int num = isIntegrationTests() ? 100000 : 1000; - StopWatch watch = new StopWatch(); - watch.start(); - INDArray arr = Nd4j.linspace(1, num, num, Nd4j.dataType()).reshape(num, 1); - SpTree tree = new SpTree(arr); - watch.stop(); - System.out.println("Tree of size " + num + " created in " + watch); - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VPTreeSerializationTests.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VPTreeSerializationTests.java deleted file mode 100644 index 86d34b603..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VPTreeSerializationTests.java +++ /dev/null @@ -1,116 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. 
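The deleted SPTreeTest.testStructure asserted that the root of an SpTree built over two 3-d points tracks their mean as its center of mass, counts both points, and allocates 2^3 = 8 child cells. A minimal sketch of that behaviour, reusing only calls present in the removed test (the wrapper class is illustrative):

```java
import org.deeplearning4j.clustering.sptree.SpTree;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class SpTreeSketch {
    public static void main(String[] args) {
        // The deleted tests ran in double precision.
        Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE);

        // Two 3-d points; the root's center of mass is their element-wise mean.
        INDArray data = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}});
        SpTree tree = new SpTree(data);

        System.out.println(tree.getCenterOfMass()); // expected: [2.5, 3.5, 4.5]
        System.out.println(tree.getCumSize());      // expected: 2
        System.out.println(tree.getNumChildren());  // expected: 8 (octree cell in 3-d)
    }
}
```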
- * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.vptree; - -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.apache.commons.lang3.SerializationUtils; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.clustering.sptree.DataPoint; -import org.junit.Ignore; -import org.junit.Test; -import org.nd4j.linalg.factory.Nd4j; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.util.ArrayList; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -@Slf4j -public class VPTreeSerializationTests extends BaseDL4JTest { - - @Test - public void testSerialization_1() throws Exception { - val points = Nd4j.rand(new int[] {10, 15}); - val treeA = new VPTree(points, true, 2); - - try (val bos = new ByteArrayOutputStream()) { - SerializationUtils.serialize(treeA, bos); - - try (val bis = new ByteArrayInputStream(bos.toByteArray())) { - VPTree treeB = SerializationUtils.deserialize(bis); - - assertEquals(points, treeA.getItems()); - assertEquals(points, treeB.getItems()); - - assertEquals(treeA.getWorkers(), treeB.getWorkers()); - - val row = points.getRow(1).dup('c'); - - val dpListA = new ArrayList(); - val dListA = new ArrayList(); - - val dpListB = new ArrayList(); - val dListB = new ArrayList(); - - treeA.search(row, 3, dpListA, dListA); - treeB.search(row, 3, dpListB, dListB); - - assertTrue(dpListA.size() != 0); - assertTrue(dListA.size() != 0); - - assertEquals(dpListA.size(), dpListB.size()); - assertEquals(dListA.size(), dListB.size()); - - for (int e = 0; e < dpListA.size(); e++) { - val rA = dpListA.get(e).getPoint(); - val rB = dpListB.get(e).getPoint(); - - assertEquals(rA, rB); - } - } - } - } - - - @Test - public void testNewConstructor_1() { - val points = Nd4j.rand(new int[] {10, 15}); - val treeA = new VPTree(points, true, 2); - - val rows = Nd4j.tear(points, 1); - - val list = new ArrayList(); - - int idx = 0; - for (val r: rows) - list.add(new DataPoint(idx++, r)); - - val treeB = new VPTree(list); - - assertEquals(points, treeA.getItems()); - assertEquals(points, treeB.getItems()); - } - - @Test - @Ignore - public void testBigTrees_1() throws Exception { - val list = new ArrayList(); - - for (int e = 0; e < 3200000; e++) { - val dp = new DataPoint(e, Nd4j.rand(new long[] {1, 300})); - } - - log.info("DataPoints created"); - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java deleted file mode 100644 index d5ced0cd2..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java +++ /dev/null @@ -1,414 +0,0 @@ -/* - * 
****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.vptree; - -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.clustering.sptree.DataPoint; -import org.joda.time.Duration; -import org.junit.BeforeClass; -import org.junit.Test; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.exception.ND4JIllegalStateException; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.Counter; -import org.nd4j.common.primitives.Pair; - -import java.util.*; - -import static org.junit.Assert.*; - -/** - * @author Anatoly Borisov - */ -@Slf4j -public class VpTreeNodeTest extends BaseDL4JTest { - - - private static class DistIndex implements Comparable { - public double dist; - public int index; - - public int compareTo(DistIndex r) { - return Double.compare(dist, r.dist); - } - } - - @BeforeClass - public static void beforeClass(){ - Nd4j.setDataType(DataType.FLOAT); - } - - @Test - public void testKnnK() { - INDArray arr = Nd4j.randn(10, 5); - VPTree t = new VPTree(arr, false); - List resultList = new ArrayList<>(); - List distances = new ArrayList<>(); - t.search(arr.getRow(0), 5, resultList, distances); - assertEquals(5, resultList.size()); - } - - - @Test - public void testParallel_1() { - int k = 5; - - for (int e = 0; e < 5; e++) { - Nd4j.getRandom().setSeed(7); - INDArray randn = Nd4j.rand(100, 3); - VPTree vpTree = new VPTree(randn, false, 4); - Nd4j.getRandom().setSeed(7); - VPTree vpTreeNoParallel = new VPTree(randn, false, 1); - List results = new ArrayList<>(); - List distances = new ArrayList<>(); - List noParallelResults = new ArrayList<>(); - List noDistances = new ArrayList<>(); - vpTree.search(randn.getRow(0), k, results, distances, true); - vpTreeNoParallel.search(randn.getRow(0), k, noParallelResults, noDistances, true); - - assertEquals("Failed at iteration " + e, k, results.size()); - assertEquals("Failed at iteration " + e, noParallelResults.size(), results.size()); - assertNotEquals(randn.getRow(0, true), results.get(0).getPoint()); - assertEquals("Failed at iteration " + e, noParallelResults, results); - assertEquals("Failed at iteration " + e, noDistances, distances); - } - } - - @Test - public void testParallel_2() { - int k = 5; - - for (int e = 0; e < 5; e++) { - Nd4j.getRandom().setSeed(7); - INDArray randn = Nd4j.rand(100, 3); - VPTree vpTree = new VPTree(randn, false, 4); - Nd4j.getRandom().setSeed(7); - VPTree vpTreeNoParallel = new VPTree(randn, false, 1); - List results = new ArrayList<>(); - List distances = new ArrayList<>(); - List noParallelResults = new 
ArrayList<>(); - List<Double> noDistances = new ArrayList<>(); - vpTree.search(randn.getRow(0), k, results, distances, false); - vpTreeNoParallel.search(randn.getRow(0), k, noParallelResults, noDistances, false); - - assertEquals("Failed at iteration " + e, k, results.size()); - assertEquals("Failed at iteration " + e, noParallelResults.size(), results.size()); - assertEquals(randn.getRow(0, true), results.get(0).getPoint()); - assertEquals("Failed at iteration " + e, noParallelResults, results); - assertEquals("Failed at iteration " + e, noDistances, distances); - } - } - - @Test - public void testReproducibility() { - val results = new ArrayList<DataPoint>(); - val distances = new ArrayList<Double>(); - Nd4j.getRandom().setSeed(7); - val randn = Nd4j.rand(1000, 100); - - for (int e = 0; e < 10; e++) { - Nd4j.getRandom().setSeed(7); - val vpTree = new VPTree(randn, false, 1); - - val cresults = new ArrayList<DataPoint>(); - val cdistances = new ArrayList<Double>(); - vpTree.search(randn.getRow(0), 5, cresults, cdistances); - - if (e == 0) { - results.addAll(cresults); - distances.addAll(cdistances); - } else { - assertEquals("Failed at iteration " + e, results, cresults); - assertEquals("Failed at iteration " + e, distances, cdistances); - } - } - } - - @Test - public void knnManualRandom() { - knnManual(Nd4j.randn(3, 5)); - } - - @Test - public void knnManualNaturals() { - knnManual(generateNaturalsMatrix(20, 2)); - } - - public static void knnManual(INDArray arr) { - Nd4j.getRandom().setSeed(7); - VPTree t = new VPTree(arr, false); - int k = 1; - int m = arr.rows(); - for (int targetIndex = 0; targetIndex < m; targetIndex++) { - // Do an exhaustive search - TreeSet<Integer> s = new TreeSet<>(); - INDArray query = arr.getRow(targetIndex, true); - - Counter<Integer> counter = new Counter<>(); - for (int j = 0; j < m; j++) { - double d = t.distance(query, (arr.getRow(j, true))); - counter.setCount(j, (float) d); - } - - PriorityQueue<Pair<Integer, Float>> pq = counter.asReversedPriorityQueue(); - // keep closest k - for (int i = 0; i < k; i++) { - Pair<Integer, Float> di = pq.poll(); - System.out.println("exhaustive d=" + di.getFirst()); - s.add(di.getFirst()); - } - - // Check what VPTree gives for results - List<DataPoint> results = new ArrayList<>(); - VPTreeFillSearch fillSearch = new VPTreeFillSearch(t, k, query); - fillSearch.search(); - results = fillSearch.getResults(); - - //List<DataPoint> items = t.getItems(); - TreeSet<Integer> resultSet = new TreeSet<>(); - - // keep k in a set - for (int i = 0; i < k; ++i) { - DataPoint result = results.get(i); - int r = result.getIndex(); - resultSet.add(r); - } - - // check - for (int r : resultSet) { - INDArray expectedResult = arr.getRow(r, true); - if (!s.contains(r)) { - fillSearch = new VPTreeFillSearch(t, k, query); - fillSearch.search(); - results = fillSearch.getResults(); - } - assertTrue(String.format( - "VPTree result %d is not in the closest %d from the exhaustive search with query point %s and result %s and target not found %s", - r, k, query.toString(), results.toString(), expectedResult.toString()), s.contains(r)); - } - } - } - - @Test - public void vpTreeTest() { - List<DataPoint> points = new ArrayList<>(); - points.add(new DataPoint(0, Nd4j.create(new double[] {55, 55}))); - points.add(new DataPoint(1, Nd4j.create(new double[] {60, 60}))); - points.add(new DataPoint(2, Nd4j.create(new double[] {65, 65}))); - VPTree tree = new VPTree(points, "euclidean"); - List<DataPoint> add = new ArrayList<>(); - List<Double> distances = new ArrayList<>(); - tree.search(Nd4j.create(new double[] {50, 50}), 1, add, distances); - DataPoint assertion = 
add.get(0); - assertEquals(new DataPoint(0, Nd4j.create(new double[] {55, 55}).reshape(1,2)), assertion); - - tree.search(Nd4j.create(new double[] {61, 61}), 2, add, distances, false); - assertion = add.get(0); - assertEquals(Nd4j.create(new double[] {60, 60}).reshape(1,2), assertion.getPoint()); - } - - @Test(expected = ND4JIllegalStateException.class) - public void vpTreeTest2() { - List points = new ArrayList<>(); - points.add(new DataPoint(0, Nd4j.create(new double[] {55, 55}))); - points.add(new DataPoint(1, Nd4j.create(new double[] {60, 60}))); - points.add(new DataPoint(2, Nd4j.create(new double[] {65, 65}))); - VPTree tree = new VPTree(points, "euclidean"); - - tree.search(Nd4j.create(1, 10), 2, new ArrayList(), new ArrayList()); - } - - @Test(expected = ND4JIllegalStateException.class) - public void vpTreeTest3() { - List points = new ArrayList<>(); - points.add(new DataPoint(0, Nd4j.create(new double[] {55, 55}))); - points.add(new DataPoint(1, Nd4j.create(new double[] {60, 60}))); - points.add(new DataPoint(2, Nd4j.create(new double[] {65, 65}))); - VPTree tree = new VPTree(points, "euclidean"); - - tree.search(Nd4j.create(2, 10), 2, new ArrayList(), new ArrayList()); - } - - @Test(expected = ND4JIllegalStateException.class) - public void vpTreeTest4() { - List points = new ArrayList<>(); - points.add(new DataPoint(0, Nd4j.create(new double[] {55, 55}))); - points.add(new DataPoint(1, Nd4j.create(new double[] {60, 60}))); - points.add(new DataPoint(2, Nd4j.create(new double[] {65, 65}))); - VPTree tree = new VPTree(points, "euclidean"); - - tree.search(Nd4j.create(2, 10, 10), 2, new ArrayList(), new ArrayList()); - } - - public static INDArray generateNaturalsMatrix(int nrows, int ncols) { - INDArray col = Nd4j.arange(0, nrows).reshape(nrows, 1).castTo(DataType.DOUBLE); - INDArray points = Nd4j.create(DataType.DOUBLE, nrows, ncols); - if (points.isColumnVectorOrScalar()) - points = col.dup(); - else { - for (int i = 0; i < ncols; i++) - points.putColumn(i, col); - } - return points; - } - - @Test - public void testVPSearchOverNaturals1D() throws Exception { - testVPSearchOverNaturalsPD(20, 1, 5); - } - - @Test - public void testVPSearchOverNaturals2D() throws Exception { - testVPSearchOverNaturalsPD(20, 2, 5); - } - - @Test - public void testTreeOrder() { - - int N = 10, dim = 1; - INDArray dataset = Nd4j.randn(N, dim); - double[] rawData = dataset.toDoubleVector(); - Arrays.sort(dataset.toDoubleVector()); - dataset = Nd4j.createFromArray(rawData).reshape(1,N); - - List points = new ArrayList<>(); - - for (int i = 0; i < rawData.length; ++i) { - points.add(new DataPoint(i, Nd4j.create(new double[]{rawData[i]}))); - } - - VPTree tree = new VPTree(points, "euclidean"); - INDArray points1 = tree.getItems(); - assertEquals(dataset, points1); - } - - @Test - public void testNearestNeighbors() { - - List points = new ArrayList<>(); - - points.add(new DataPoint(0, Nd4j.create(new double[] {0.83494041, 1.70294823, -1.34172191, 0.02350972, - -0.87519361, 0.64401935, -0.5634212, -1.1274308, - 0.19245948, -0.11349026}))); - points.add(new DataPoint(1, Nd4j.create(new double[] {-0.41115537, -0.7686138, -0.67923172, 1.01638281, - 0.04390801, 0.29753166, 0.78915771, -0.13564866, - -1.06053692, -0.15953041}))); - - VPTree tree = new VPTree(points, "euclidean"); - - List results = new ArrayList<>(); - List distances = new ArrayList<>(); - - final int k = 1; - double[] input = new double[]{0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5}; - tree.search(Nd4j.createFromArray(input), k, results, 
distances); - assertEquals(k, distances.size()); - assertEquals(2.7755637844503016, distances.get(0), 1e-5); - - double[] results_pattern = new double[]{-0.41115537, -0.7686138, -0.67923172, 1.01638281, 0.04390801, - 0.29753166, 0.78915771, -0.13564866, -1.06053692, -0.15953041}; - for (int i = 0; i < results_pattern.length; ++i) { - assertEquals(results_pattern[i], results.get(0).getPoint().getDouble(i), 1e-5); - } - } - - @Test - public void performanceTest() { - final int dim = 300; - final int rows = 8000; - final int k = 5; - - INDArray inputArray = Nd4j.linspace(DataType.DOUBLE, 0.0, 1.0, rows * dim).reshape(rows, dim); - - //INDArray inputArray = Nd4j.randn(DataType.DOUBLE, 200000, dim); - long start = System.currentTimeMillis(); - VPTree tree = new VPTree(inputArray, "euclidean"); - long end = System.currentTimeMillis(); - Duration duration = new Duration(start, end); - System.out.println("Elapsed time for tree construction " + duration.getStandardSeconds()); - - double[] input = new double[dim]; - for (int i = 0; i < dim; ++i) { - input[i] = 119; - } - List<DataPoint> results = new ArrayList<>(); - List<Double> distances = new ArrayList<>(); - start = System.currentTimeMillis(); - tree.search(Nd4j.createFromArray(input), k, results, distances); - end = System.currentTimeMillis(); - duration = new Duration(start, end); - System.out.println("Elapsed time for tree search " + duration.getStandardSeconds()); - assertEquals(1590.2987519949422, distances.get(0), 1e-4); - } - - public static void testVPSearchOverNaturalsPD(int nrows, int ncols, int K) throws Exception { - final int queryPoint = 12; - - INDArray points = generateNaturalsMatrix(nrows, ncols); - INDArray query = Nd4j.zeros(DataType.DOUBLE, 1, ncols); - for (int i = 0; i < ncols; i++) - query.putScalar(0, i, queryPoint); - - INDArray trueResults = Nd4j.zeros(DataType.DOUBLE, K, ncols); - for (int j = 0; j < K; j++) { - int pt = queryPoint - K / 2 + j; - for (int i = 0; i < ncols; i++) - trueResults.putScalar(j, i, pt); - } - - VPTree tree = new VPTree(points, "euclidean", 1, false); - - List<DataPoint> results = new ArrayList<>(); - List<Double> distances = new ArrayList<>(); - tree.search(query, K, results, distances, false); - int dimensionToSort = 0; - - INDArray sortedResults = Nd4j.zeros(DataType.DOUBLE, K, ncols); - int i = 0; - for (DataPoint p : results) { - sortedResults.putRow(i++, p.getPoint()); - } - - sortedResults = Nd4j.sort(sortedResults, dimensionToSort, true); - assertTrue(trueResults.equalsWithEps(sortedResults, 1e-5)); - - VPTreeFillSearch fillSearch = new VPTreeFillSearch(tree, K, query); - fillSearch.search(); - results = fillSearch.getResults(); - sortedResults = Nd4j.zeros(DataType.FLOAT, K, ncols); - i = 0; - for (DataPoint p : results) - sortedResults.putRow(i++, p.getPoint()); - INDArray[] sortedWithIndices = Nd4j.sortWithIndices(sortedResults, dimensionToSort, true); - sortedResults = sortedWithIndices[1]; - assertEquals(trueResults.sumNumber().doubleValue(), sortedResults.sumNumber().doubleValue(), 1e-5); - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml deleted file mode 100644 index b95ab2c73..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml +++ /dev/null @@ -1,54 +0,0 @@ - - - - - - 4.0.0 - - - org.deeplearning4j - deeplearning4j-parent - 1.0.0-SNAPSHOT - - - deeplearning4j-nearestneighbors-parent - pom - - deeplearning4j-nearestneighbors-parent - - - deeplearning4j-nearestneighbor-server - 
nearestneighbor-core - deeplearning4j-nearestneighbors-client - deeplearning4j-nearestneighbors-model - - - - - test-nd4j-native - - - test-nd4j-cuda-11.0 - - - diff --git a/deeplearning4j/pom.xml b/deeplearning4j/pom.xml index acd187417..c17e67df1 100644 --- a/deeplearning4j/pom.xml +++ b/deeplearning4j/pom.xml @@ -36,7 +36,6 @@ pom DeepLearning4j - http://deeplearning4j.org/ DeepLearning for java @@ -59,7 +58,6 @@ deeplearning4j-modelimport deeplearning4j-modelexport-solr deeplearning4j-zoo - deeplearning4j-nearestneighbors-parent deeplearning4j-data deeplearning4j-manifold dl4j-integration-tests diff --git a/libnd4j/CMakeLists.txt b/libnd4j/CMakeLists.txt index 695acec35..1e2633e07 100755 --- a/libnd4j/CMakeLists.txt +++ b/libnd4j/CMakeLists.txt @@ -1,8 +1,11 @@ cmake_minimum_required(VERSION 3.15) project(libnd4j) -set(CMAKE_VERBOSE_MAKEFILE OFF) +set(CMAKE_VERBOSE_MAKEFILE ON) + + +set (CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake") +message("CMAKE MODULE PATH ${CMAKE_MODULE_PATH}") -set(CMAKE_MODULE_PATH "${CMAKE_SOURCE_DIR}/cmake" ${CMAKE_MODULE_PATH}) #ensure we create lib files set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS OFF) @@ -18,8 +21,99 @@ set(FLATBUFFERS_BUILD_FLATC "OFF" CACHE STRING "Hack to disable flatc build" FOR set(CMAKE_CXX_STANDARD 11) +#/////////////////////////////////////////////////////////////////////////////// +# genCompilation: Generates cpp, cu files +# INPUT: +# $FILE_ITEM template-configuration that utilizes libnd4j type, macros helpers +# defined inside { include/types/types.h, include/system/type_boilerplate.h} +# OUTPUT: +# $CUSTOMOPS_GENERIC_SOURCES generated files will be added into this List +#//////////////////////////////////////////////////////////////////////////////// +# A simple template-configuration file example: +# // hints and defines what types will be generated +# #cmakedefine LIBND4J_TYPE_GEN +# #cmakedefine FLOAT_TYPE_GEN +# // below if defines blocks are needed for correctly handling multiple types +# #if defined(LIBND4J_TYPE_GEN) +# BUILD_DOUBLE_TEMPLATE(template void someFunc, (arg_list,..), +# LIBND4J_TYPES_@FL_TYPE_INDEX@, INDEXING_TYPES); +# #endif +# #if defined(FLOAT_TYPE_GEN) +# BUILD_SINGLE_TEMPLATE(template class SomeClass,, FLOAT_TYPES_@FL_TYPE_INDEX@); +# #endif +#//////////////////////////////////////////////////////////////////////////////// + +set_property(GLOBAL PROPERTY JOB_POOLS one_jobs=1 two_jobs=2) + + + + +function(genCompilation FILE_ITEM) + get_filename_component(FILE_ITEM_WE ${FL_ITEM} NAME_WE) + + set(EXTENSION "cpp") + + if(FL_ITEM MATCHES "cu.in$") + set(EXTENSION "cu") + endif() + + file(READ ${FL_ITEM} CONTENT_FL) + #check content for types + + #set all to false + set (FLOAT_TYPE_GEN 0) + set (INT_TYPE_GEN 0) + set (LIBND4J_TYPE_GEN 0) + set (PAIRWISE_TYPE_GEN 0) + set (RANGE_STOP -1) + + string(REGEX MATCHALL "#cmakedefine[ \t]+[^_]+_TYPE_GEN" TYPE_MATCHES ${CONTENT_FL}) + + foreach(TYPEX ${TYPE_MATCHES}) + set(STOP -1) + if(TYPEX MATCHES "INT_TYPE_GEN$") + set (INT_TYPE_GEN 1) + set(STOP 7) + endif() + if(TYPEX MATCHES "LIBND4J_TYPE_GEN$") + set (LIBND4J_TYPE_GEN 1) + set(STOP 9) + endif() + if(TYPEX MATCHES "FLOAT_TYPE_GEN$") + set (FLOAT_TYPE_GEN 1) + set(STOP 3) + endif() + if(TYPEX MATCHES "PAIRWISE_TYPE_GEN$") + set (PAIRWISE_TYPE_GEN 1) + set(STOP 12) + endif() + if(STOP GREATER RANGE_STOP) + set(RANGE_STOP ${STOP}) + endif() + + endforeach() + + if(RANGE_STOP GREATER -1) + foreach(FL_TYPE_INDEX RANGE 0 ${RANGE_STOP}) + # set OFF if the index is above + if(FL_TYPE_INDEX GREATER 3) + set 
(FLOAT_TYPE_GEN 0) + endif() + if(FL_TYPE_INDEX GREATER 7) + set (INT_TYPE_GEN 0) + endif() + if(FL_TYPE_INDEX GREATER 9) + set (LIBND4J_TYPE_GEN 0) + endif() + set(GENERATED_SOURCE "${CMAKE_BINARY_DIR}/compilation_units/${FILE_ITEM_WE}_${FL_TYPE_INDEX}.${EXTENSION}") + configure_file( "${FL_ITEM}" "${GENERATED_SOURCE}" @ONLY) + LIST(APPEND CUSTOMOPS_GENERIC_SOURCES ${GENERATED_SOURCE} ) + endforeach() + endif() + + set(CUSTOMOPS_GENERIC_SOURCES ${CUSTOMOPS_GENERIC_SOURCES} PARENT_SCOPE) +endfunction() -include(GenCompilation) if (SD_CUDA) enable_language(CUDA) @@ -42,6 +136,7 @@ endif() # -fsanitize=address # -fsanitize=leak if (SD_ANDROID_BUILD) + set_property(GLOBAL PROPERTY JOB_POOLS one_job=1 two_jobs=2) set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3 -fPIC -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else -D_RELEASE=true") set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -O0 -g -fPIC -Wno-braced-scalar-init -Wno-delete-non-virtual-dtor -Wno-unused-command-line-argument -Wno-dangling-else") elseif (APPLE) @@ -315,13 +410,13 @@ elseif(DISTRIBUTION STREQUAL "CentOS") else() set(CPACK_RPM_PACKAGE_ARCHITECTURE "i686") endif() - set(CPACK_PACKAGE_CONTACT "raver119") + set(CPACK_PACKAGE_CONTACT "agibsonccc") set(CPACK_RPM_PACKAGE_GROUP "Development/Tools") set(CPACK_RPM_PACKAGE_LICENSE "Apache-2.0") set(CPACK_RPM_PACKAGE_SUGGESTS "cuda") # Build deps: atlas blas lapack cmake3 devtoolset-4-gcc devtoolset-4-gcc-c++ set(CPACK_RPM_PACKAGE_REQUIRES "") - set(CPACK_RPM_PACKAGE_URL "https://github.com/deeplearning4j/libnd4j") + set(CPACK_RPM_PACKAGE_URL "https://github.com/eclipse/deeplearning4j/libnd4j") set(CPACK_GENERATOR "RPM") set(CPACK_PACKAGE_FILE_NAME ${CPACK_PACKAGE_NAME}-${CPACK_PACKAGE_VERSION}.fc${RELEASE}.${CPACK_RPM_PACKAGE_ARCHITECTURE}) set(CPACK_RPM_POST_INSTALL_SCRIPT_FILE "${CMAKE_CURRENT_SOURCE_DIR}/cmake/postinst") diff --git a/libnd4j/README.md b/libnd4j/README.md index 4dbb63ba9..ffeb5c2f3 100644 --- a/libnd4j/README.md +++ b/libnd4j/README.md @@ -45,10 +45,9 @@ You can find the same information for the older Toolkit versions [in the CUDA ar [Download the NDK](https://developer.android.com/ndk/downloads/), extract it somewhere, and execute the following commands, replacing `android-xxx` with either `android-arm` or `android-x86`: ```bash -git clone https://github.com/deeplearning4j/libnd4j -git clone https://github.com/deeplearning4j/nd4j +git clone https://github.com/eclipse/deeplearning4j export ANDROID_NDK=/path/to/android-ndk/ -cd libnd4j +cd deeplearning4j/libnd4j bash buildnativeoperations.sh -platform android-xxx cd ../nd4j mvn clean install -Djavacpp.platform=android-xxx -DskipTests -pl '!:nd4j-cuda-9.0,!:nd4j-cuda-9.0-platform,!:nd4j-tests' diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt index 15fd70c69..d65c9660a 100755 --- a/libnd4j/blas/CMakeLists.txt +++ b/libnd4j/blas/CMakeLists.txt @@ -176,6 +176,8 @@ if(SD_CUDA) set(EXPM " -D__ND4J_EXPERIMENTAL__=true") endif() + + # the only difference for debug mode here is host/device debug symbols set(CMAKE_CUDA_FLAGS_DEBUG " -G -g") @@ -185,6 +187,18 @@ if(SD_CUDA) set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler=-fPIC") endif() + if(WIN32) + message("In windows, setting cublas library and cusolver library") + if(NOT DEFINED CUDA_cublas_LIBRARY) + set(CUDA_cublas_LIBRARY ${CUDA_HOME}/lib/x64/cublas.lib) + endif() + + if(NOT DEFINED CUDA_cusolver_LIBRARY) + set(CUDA_cusolver_LIBRARY ${CUDA_HOME}/lib/x64/cusolver.lib) + 
endif() + endif() + + string( TOLOWER "${COMPUTE}" COMPUTE_CMP ) if ("${COMPUTE_CMP}" STREQUAL "all") CUDA_SELECT_NVCC_ARCH_FLAGS(CUDA_ARCH_FLAGS "Common") @@ -343,16 +357,27 @@ elseif(SD_CPU) message("CPU BLAS") add_definitions(-D__CPUBLAS__=true) + add_library(samediff_obj OBJECT ${LEGACY_SOURCES} ${LOOPS_SOURCES} ${HELPERS_SOURCES} ${EXEC_SOURCES} ${ARRAY_SOURCES} ${TYPES_SOURCES} ${MEMORY_SOURCES} ${GRAPH_SOURCES} ${CUSTOMOPS_SOURCES} ${EXCEPTIONS_SOURCES} ${INDEXING_SOURCES} ${CUSTOMOPS_MKLDNN_SOURCES} ${CUSTOMOPS_ARMCOMPUTE_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${OPS_SOURCES} ${PERF_SOURCES}) + if(IOS) add_library(${SD_LIBRARY_NAME} STATIC $) else() # build shared library by default or when it's explicitly requested if(NOT SD_STATIC_LIB OR SD_SHARED_LIB) add_library(${SD_LIBRARY_NAME} SHARED $) + if(ANDROID) + # See: https://www.scivision.dev/cmake-ninja-job-pool-limited-memory/ + # See: https://cmake.org/cmake/help/v3.0/command/cmake_host_system_information.html + # See: https://cmake.org/cmake/help/latest/prop_gbl/JOB_POOLS.html + cmake_host_system_information(RESULT _logical_cores QUERY NUMBER_OF_LOGICAL_CORES) + if(_logical_cores LESS 4) + set_target_properties(${SD_LIBRARY_NAME} PROPERTIES JOB_POOL_COMPILE one_jobs) + endif() + endif() endif() if (SD_STATIC_LIB AND SD_SHARED_LIB) diff --git a/libnd4j/buildnativeoperations.sh b/libnd4j/buildnativeoperations.sh index bcb664a76..ec20f400e 100755 --- a/libnd4j/buildnativeoperations.sh +++ b/libnd4j/buildnativeoperations.sh @@ -86,7 +86,10 @@ VERBOSE="false" VERBOSE_ARG="VERBOSE=1" HELPER= CHECK_VECTORIZATION="OFF" +SYS_ROOT= +EXTRA_LINK_FLAGS= NAME= +EXTRA_CUDA_FLAGS= while [[ $# -gt 0 ]] do key="$1" @@ -399,6 +402,11 @@ if [ -z "$BUILD" ]; then fi +if [ -z "$SYS_ROOT" ]; then + export SYS_ROOT="" +fi + + if [ -z "$CHIP" ]; then CHIP="cpu" fi @@ -411,9 +419,7 @@ if [ -z "$PACKAGING" ]; then PACKAGING="none" fi -if [ -z "$COMPUTE" ]; then - COMPUTE="all" -fi + if [ "$CHIP_EXTENSION" == "avx512" ] || [ "$ARCH" == "avx512" ]; then CHIP_EXTENSION="avx512" @@ -430,6 +436,14 @@ if [ -z "$ARCH" ]; then ARCH="x86-64" fi +if [ -z "$COMPUTE" ]; then + if [ "$ARCH" == "x86-64" ]; then + COMPUTE="5.0 5.2 5.3 6.0 6.2 8.0" + else + COMPUTE="5.0 5.2 5.3 6.0 6.2" + fi +fi + OPERATIONS_ARG= if [ -z "$OPERATIONS" ]; then @@ -503,6 +517,13 @@ if [ "$TESTS" == "true" ]; then TESTS_ARG="-DSD_BUILD_TESTS=ON" fi + +if [ "$SYS_ROOT" != "" ]; then + EXTRA_SYSROOT="-DCMAKE_SYSROOT=$SYS_ROOT" + else + EXTRA_SYSROOT="" +fi + ARCH_ARG="-DSD_ARCH=$ARCH -DSD_EXTENSION=$CHIP_EXTENSION" CUDA_COMPUTE="-DCOMPUTE=\"$COMPUTE\"" @@ -511,6 +532,16 @@ if [ "$CHIP" == "cuda" ] && [ -n "$CHIP_VERSION" ]; then case $OS in linux*) export CUDA_PATH="/usr/local/cuda-$CHIP_VERSION/" + # Cross compilation for jetson nano + if [ "$ARCH" != "x86-64" ]; then + if [ "$ARCH" == "armv8-a" ]; then + export EXTRA_CUDA_FLAGS="-DCUDA_TARGET_CPU_ARCH=AARCH64" + else + export EXTRA_CUDA_FLAGS="-DCUDA_TARGET_CPU_ARCH=ARM" + fi + else + export EXTRA_CUDA_FLAGS="" + fi ;; macosx*) export CUDA_PATH="/Developer/NVIDIA/CUDA-$CHIP_VERSION/" @@ -578,6 +609,13 @@ else IFS=' ' fi +LINKER_FLAGS="" +if [ "$EXTRA_LINK_FLAGS" != "" ]; then + LINKER_FLAGS="-DCMAKE_CXX_LINK_FLAGS=$EXTRA_LINK_FLAGS -DCMAKE_EXE_LINKER_FLAGS=$EXTRA_LINK_FLAGS -DCMAKE_CUDA_FLAGS=$EXTRA_LINK_FLAGS" +fi + + + echo PACKAGING = "${PACKAGING}" echo BUILD = "${BUILD}" echo CHIP = "${CHIP}" @@ -594,9 +632,12 @@ echo NAME = "${NAME_ARG}" echo OPENBLAS_PATH = "$OPENBLAS_PATH" echo CHECK_VECTORIZATION = "$CHECK_VECTORIZATION" echo 
HELPERS = "$HELPERS" +echo EXTRA_LINK_FLAGS = "$EXTRA_LINK_FLAGS" +echo EXTRA_CUDA_FLAGS = "$EXTRA_CUDA_FLAGS" +echo EXTRA_SYSROOT = "$EXTRA_SYSROOT" mkbuilddir pwd -eval "$CMAKE_COMMAND" "$BLAS_ARG" "$ARCH_ARG" "$NAME_ARG" -DSD_CHECK_VECTORIZATION="${CHECK_VECTORIZATION}" "$HELPERS" "$SHARED_LIBS_ARG" "$MINIFIER_ARG" "$OPERATIONS_ARG" "$BUILD_TYPE" "$PACKAGING_ARG" "$EXPERIMENTAL_ARG" "$TESTS_ARG" "$CUDA_COMPUTE" -DOPENBLAS_PATH="$OPENBLAS_PATH" -DDEV=FALSE -DCMAKE_NEED_RESPONSE=YES -DMKL_MULTI_THREADED=TRUE ../.. +eval "$CMAKE_COMMAND" "$EXTRA_SYSROOT" "$LINKER_FLAGS" "$EXTRA_CUDA_FLAGS" "$BLAS_ARG" "$ARCH_ARG" "$NAME_ARG" -DSD_CHECK_VECTORIZATION="${CHECK_VECTORIZATION}" "$HELPERS" "$SHARED_LIBS_ARG" "$MINIFIER_ARG" "$OPERATIONS_ARG" "$BUILD_TYPE" "$PACKAGING_ARG" "$EXPERIMENTAL_ARG" "$TESTS_ARG" "$CUDA_COMPUTE" -DOPENBLAS_PATH="$OPENBLAS_PATH" -DDEV=FALSE -DCMAKE_NEED_RESPONSE=YES -DMKL_MULTI_THREADED=TRUE ../.. if [ "$PARALLEL" == "true" ]; then MAKE_ARGUMENTS="$MAKE_ARGUMENTS -j $MAKEJ" diff --git a/libnd4j/cmake/GenCompilation.cmake b/libnd4j/cmake/GenCompilation.cmake index 0aca627c6..0232acfb2 100644 --- a/libnd4j/cmake/GenCompilation.cmake +++ b/libnd4j/cmake/GenCompilation.cmake @@ -17,90 +17,3 @@ # SPDX-License-Identifier: Apache-2.0 ################################################################################ -#/////////////////////////////////////////////////////////////////////////////// -# genCompilation: Generates cpp, cu files -# INPUT: -# $FILE_ITEM template-configuration that utilizes libnd4j type, macros helpers -# defined inside { include/types/types.h, include/system/type_boilerplate.h} -# OUTPUT: -# $CUSTOMOPS_GENERIC_SOURCES generated files will be added into this List -#//////////////////////////////////////////////////////////////////////////////// -# A simple template-configuration file example: -# // hints and defines what types will be generated -# #cmakedefine LIBND4J_TYPE_GEN -# #cmakedefine FLOAT_TYPE_GEN -# // below if defines blocks are needed for correctly handling multiple types -# #if defined(LIBND4J_TYPE_GEN) -# BUILD_DOUBLE_TEMPLATE(template void someFunc, (arg_list,..), -# LIBND4J_TYPES_@FL_TYPE_INDEX@, INDEXING_TYPES); -# #endif -# #if defined(FLOAT_TYPE_GEN) -# BUILD_SINGLE_TEMPLATE(template class SomeClass,, FLOAT_TYPES_@FL_TYPE_INDEX@); -# #endif -#//////////////////////////////////////////////////////////////////////////////// - -function(genCompilation FILE_ITEM) - get_filename_component(FILE_ITEM_WE ${FL_ITEM} NAME_WE) - - set(EXTENSION "cpp") - - if(FL_ITEM MATCHES "cu.in$") - set(EXTENSION "cu") - endif() - - file(READ ${FL_ITEM} CONTENT_FL) - #check content for types - - #set all to false - set (FLOAT_TYPE_GEN 0) - set (INT_TYPE_GEN 0) - set (LIBND4J_TYPE_GEN 0) - set (PAIRWISE_TYPE_GEN 0) - set (RANGE_STOP -1) - - string(REGEX MATCHALL "#cmakedefine[ \t]+[^_]+_TYPE_GEN" TYPE_MATCHES ${CONTENT_FL}) - - foreach(TYPEX ${TYPE_MATCHES}) - set(STOP -1) - if(TYPEX MATCHES "INT_TYPE_GEN$") - set (INT_TYPE_GEN 1) - set(STOP 7) - endif() - if(TYPEX MATCHES "LIBND4J_TYPE_GEN$") - set (LIBND4J_TYPE_GEN 1) - set(STOP 9) - endif() - if(TYPEX MATCHES "FLOAT_TYPE_GEN$") - set (FLOAT_TYPE_GEN 1) - set(STOP 3) - endif() - if(TYPEX MATCHES "PAIRWISE_TYPE_GEN$") - set (PAIRWISE_TYPE_GEN 1) - set(STOP 12) - endif() - if(STOP GREATER RANGE_STOP) - set(RANGE_STOP ${STOP}) - endif() - - endforeach() - - if(RANGE_STOP GREATER -1) - foreach(FL_TYPE_INDEX RANGE 0 ${RANGE_STOP}) - # set OFF if the index is above - if(FL_TYPE_INDEX GREATER 3) - set (FLOAT_TYPE_GEN 
0) - endif() - if(FL_TYPE_INDEX GREATER 7) - set (INT_TYPE_GEN 0) - endif() - if(FL_TYPE_INDEX GREATER 9) - set (LIBND4J_TYPE_GEN 0) - endif() - set(GENERATED_SOURCE "${CMAKE_BINARY_DIR}/compilation_units/${FILE_ITEM_WE}_${FL_TYPE_INDEX}.${EXTENSION}") - configure_file( "${FL_ITEM}" "${GENERATED_SOURCE}" @ONLY) - LIST(APPEND CUSTOMOPS_GENERIC_SOURCES ${GENERATED_SOURCE} ) - endforeach() - endif() - - set(CUSTOMOPS_GENERIC_SOURCES ${CUSTOMOPS_GENERIC_SOURCES} PARENT_SCOPE) -endfunction() \ No newline at end of file diff --git a/libnd4j/nano_build.sh b/libnd4j/nano_build.sh new file mode 100644 index 000000000..73860689a --- /dev/null +++ b/libnd4j/nano_build.sh @@ -0,0 +1,234 @@ +#!/usr/bin/env bash +# +# /* ****************************************************************************** +# * +# * +# * This program and the accompanying materials are made available under the +# * terms of the Apache License, Version 2.0 which is available at +# * https://www.apache.org/licenses/LICENSE-2.0. +# * +# * See the NOTICE file distributed with this work for additional +# * information regarding copyright ownership. +# * Unless required by applicable law or agreed to in writing, software +# * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# * License for the specific language governing permissions and limitations +# * under the License. +# * +# * SPDX-License-Identifier: Apache-2.0 +# ******************************************************************************/ +# + +function message { + echo "BUILDER:::: ${@}" +} +if [ -z "${BUILD_USING_MAVEN}" ]; then export BUILD_USING_MAVEN=; fi +if [ -z "${CURRENT_TARGET}" ]; then export CURRENT_TARGET=arm32; fi +if [ -z "${HAS_ARMCOMPUTE}" ]; then export ARMCOMPUTE_DEBUG=1; fi +if [ -z "${ARMCOMPUTE_DEBUG}" ]; then export HAS_ARMCOMPUTE=1; fi +if [ -z "${ARMCOMPUTE_TAG}" ]; then export ARMCOMPUTE_TAG=v20.05; fi +if [ -z "${LIBND4J_BUILD_MODE}" ]; then export LIBND4J_BUILD_MODE=Release; fi +if [ -z "${ANDROID_VERSION}" ]; then export ANDROID_VERSION=21; fi +if [ -z "${HAS_ARMCOMPUTE}" ]; then export HAS_ARMCOMPUTE=1; fi + +OTHER_ARGS=() +while [[ $# -gt 0 ]] +do +key="$1" + +case $key in + -a|--arch) + CURRENT_TARGET="$2" + shift + shift + ;; + -m|--mvn) + BUILD_USING_MAVEN="mvn" + shift + ;; + *) + OTHER_ARGS+=("$1") + shift + ;; +esac +done + +CC_URL32="https://developer.arm.com/-/media/Files/downloads/gnu-a/8.3-2019.03/binrel/gcc-arm-8.3-2019.03-x86_64-arm-linux-gnueabihf.tar.xz?revision=e09a1c45-0ed3-4a8e-b06b-db3978fd8d56&la=en&hash=93ED4444B8B3A812B893373B490B90BBB28FD2E3" +CC_URL64="https://developer.arm.com/-/media/Files/downloads/gnu-a/8.3-2019.03/binrel/gcc-arm-8.3-2019.03-x86_64-aarch64-linux-gnu.tar.xz?revision=2e88a73f-d233-4f96-b1f4-d8b36e9bb0b9&la=en&hash=167687FADA00B73D20EED2A67D0939A197504ACD" +CC_ANDROID="https://dl.google.com/android/repository/android-ndk-r21d-linux-x86_64.zip" +COMPILER_ARRS=( "${CC_URL32}" "${CC_URL64}" "${CC_ANDROID}" "${CC_ANDROID}" ) +COMPILER_DOWNLOAD_CMD_LIST=( download_extract_xz download_extract_xz download_extract_unzip download_extract_unzip ) +COMPILER_DESTDIR=( "arm32" "arm64" "android" "android" ) +PREFIXES=( arm-linux-gnueabihf aarch64-linux-gnu arm-linux-androideabi aarch64-linux-android ) +TARGET_INDEX=-1 + +for i in "${!TARGET_ARRS[@]}"; do + if [[ "${TARGET_ARRS[$i]}" = "${CURRENT_TARGET}" ]]; then + TARGET_INDEX=${i} + fi +done + +if [ ${TARGET_INDEX} -eq -1 ];then + message "could not find 
${CURRENT_TARGET} in ${TARGET_ARRS[@]}" + exit 1 +fi + +#BASE_DIR=${HOME}/pi +#https://stackoverflow.com/questions/59895/how-to-get-the-source-directory-of-a-bash-script-from-within-the-script-itself +SOURCE="${BASH_SOURCE[0]}" +while [ -h "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symlink + DIR="$( cd -P "$( dirname "$SOURCE" )" >/dev/null 2>&1 && pwd )" + SOURCE="$(readlink "$SOURCE")" + [[ $SOURCE != /* ]] && SOURCE="$DIR/$SOURCE" # if $SOURCE was a relative symlink, we need to resolve it relative to the path where the symlink file was located +done +BASE_DIR="$( cd -P "$( dirname "$SOURCE" )" >/dev/null 2>&1 && pwd )" + +export CROSS_COMPILER_URL="https://developer.nvidia.com/embedded/dlc/l4t-gcc-toolchain-64-bit-32-5" +export CROSS_COMPILER_DIR=${BASE_DIR}/compile_tools/cross_compiler_${COMPILER_DESTDIR[$TARGET_INDEX]} +export COMPILER_DOWNLOAD_CMD=${COMPILER_DOWNLOAD_CMD_LIST[$TARGET_INDEX]} +export DETECT=${DETECT_LIST[$TARGET_INDEX]} +export LIBND4J_PLATFORM_EXT=${LIBND4J_PLATFORM_EXT_LIST[$TARGET_INDEX]} +export TARGET_OS="linux" +export LIBND4J_PLATFORM="linux-arm64" +export PREFIX=${PREFIXES[$TARGET_INDEX]} + +export CMAKE=cmake #/snap/bin/cmake +mkdir -p ${BASE_DIR}/compile_tools/ + + +mkdir -p ${BASE_DIR} +if [ -z "${THIRD_PARTY}" ]; then export THIRD_PARTY=${BASE_DIR}/third_party; fi # THIRD_PARTY was never set in this script; default path assumed +mkdir -p ${THIRD_PARTY} + +#change directory to base +cd $BASE_DIR + +function check_requirements { + for i in "${@}" + do + if [ ! -e "$i" ]; then + message "missing: ${i}" + exit 2 + fi + done +} + +function rename_top_folder { + for dir in ${1}/* + do + if [ -d "$dir" ] + then + mv "${dir}" "${1}/folder/" + message "${dir} => ${1}/folder/" + break + fi + done +} + +function download_extract_base { + #$1 is the extract argument, $2 is the url, $3 is the destination dir + if [ ! -f ${3}_file ]; then + message "download" + wget --quiet --show-progress -O ${3}_file ${2} + fi + + message "extract $@" + #extract + mkdir -p ${3} + if [ ${1} = "-unzip" ]; then + command="unzip -qq ${3}_file -d ${3} " + else + command="tar ${1} ${3}_file --directory=${3} " + fi + message $command + $command + check_requirements "${3}" +} + +function download_extract { + download_extract_base -xzf $@ +} + +function download_extract_xz { + download_extract_base -xf $@ +} + +function download_extract_unzip { + download_extract_base -unzip $@ +} + +function git_check { + #$1 is url #$2 is dir #$3 is an optional tag or branch + command= + if [ -n "$3" ]; then + command="git clone --quiet --depth 1 --branch ${3} ${1} ${2}" + else + command="git clone --quiet ${1} ${2}" + fi + message "$command" + $command + check_requirements "${2}" +} + +#fix pi debug linkage manually and also make the linker use ld.gold +function fix_pi_linker { + #$1 BINUTILS folder + if [ ! -f ${1}/ld.original ]; then + mv ${1}/ld ${1}/ld.original + fi + rm -f ${1}/ld + printf '#!/usr/bin/env bash\n'"${1}/ld.gold --long-plt \$*">${1}/ld + chmod +x ${1}/ld +} + +if [ !
-d ${CROSS_COMPILER_DIR}/folder ]; then + #out file + message "download CROSS_COMPILER" + ${COMPILER_DOWNLOAD_CMD} ${CROSS_COMPILER_URL} ${CROSS_COMPILER_DIR} + message "rename top folder (instead of --strip-components=1)" + rename_top_folder ${CROSS_COMPILER_DIR} +fi + +export CROSS_COMPILER_DIR=${CROSS_COMPILER_DIR}/folder +export BINUTILS_BIN=${CROSS_COMPILER_DIR}/${PREFIX}/bin +export COMPILER_PREFIX=${CROSS_COMPILER_DIR}/bin/${PREFIX} +export TOOLCHAIN_PREFIX=${COMPILER_PREFIX} +export SYS_ROOT=${CROSS_COMPILER_DIR}/${PREFIX}/libc +export JAVA_LIBRARY_PATH=${CROSS_COMPILER_DIR}/${PREFIX}/lib # mirrors pi_build.sh; referenced by the mvn commands below +#LD_LIBRARY_PATH=${CROSS_COMPILER_DIR}/lib:$LD_LIBRARY_PATH +export CC_EXE="gcc" +export CXX_EXE="g++" +export RANLIB="${BINUTILS_BIN}/ranlib" +export LD="${BINUTILS_BIN}/ld" +export AR="${BINUTILS_BIN}/ar" +export BLAS_XTRA="CC=${COMPILER_PREFIX}-${CC_EXE} AR=${AR} RANLIB=${RANLIB} CFLAGS=--sysroot=${SYS_ROOT} LDFLAGS=\"-L${SYS_ROOT}/../lib/ -lm\"" + + +check_requirements "${COMPILER_PREFIX}-${CC_EXE}" + + +#the toolchain is detected passively from the CMake cache, so the build folder has to be deleted manually when the toolchain changes +detect=$(grep -o "${PREFIX}" "${BASE_DIR}/blasbuild/cpu/CMakeCache.txt" 2>/dev/null) +if [ -z "${detect}" ] ;then +message "remove blasbuild folder " +rm -rf $BASE_DIR/blasbuild/ +else +message "keep blasbuild folder" +fi + +if [ -z "${BUILD_USING_MAVEN}" ] ;then +message "let's build just the library" +DHELPER=" -h armcompute " +bash ./buildnativeoperations.sh -o ${LIBND4J_PLATFORM} -t ${DHELPER} -j $(nproc) +else +message "cd $BASE_DIR/.. " +cd $BASE_DIR/.. +message "let's build the jars" +export DHELPER=" -Dlibnd4j.helper=armcompute " +if [ "${DEPLOY}" ]; then + echo "Deploying to maven" + mvn -Pgithub deploy --batch-mode -Dlibnd4j.platform=${LIBND4J_PLATFORM} -Djavacpp.platform=${LIBND4J_PLATFORM} -DprotocCommand=protoc -Djavacpp.platform.compiler=${COMPILER_PREFIX}-${CC_EXE} -Djava.library.path=${JAVA_LIBRARY_PATH} ${DHELPER} -pl ":libnd4j,:nd4j-native" --also-make -DskipTests -Dmaven.test.skip=true -Dmaven.javadoc.skip=true + else + echo "Installing to local repo" + mvn install -Dlibnd4j.platform=${LIBND4J_PLATFORM} -Djavacpp.platform=${LIBND4J_PLATFORM} -DprotocCommand=protoc -Djavacpp.platform.compiler=${COMPILER_PREFIX}-${CC_EXE} -Djava.library.path=${JAVA_LIBRARY_PATH} ${DHELPER} -pl ":libnd4j" --also-make -DskipTests -Dmaven.test.skip=true -Dmaven.javadoc.skip=true +fi + +fi diff --git a/libnd4j/pi_build.sh b/libnd4j/pi_build.sh index 5bc3e0109..8a536d155 100755 --- a/libnd4j/pi_build.sh +++ b/libnd4j/pi_build.sh @@ -1,4 +1,4 @@ -#!/bin/bash +#!/usr/bin/env bash # # /* ****************************************************************************** # * @@ -22,14 +22,15 @@ function message { echo "BUILDER:::: ${@}" } +if [ -z "${BUILD_USING_MAVEN}" ]; then export BUILD_USING_MAVEN=; fi +if [ -z "${CURRENT_TARGET}" ]; then export CURRENT_TARGET=arm32; fi +if [ -z "${ARMCOMPUTE_DEBUG}" ]; then export ARMCOMPUTE_DEBUG=1; fi +if [ -z "${HAS_ARMCOMPUTE}" ]; then export HAS_ARMCOMPUTE=1; fi +if [ -z "${ARMCOMPUTE_TAG}" ]; then export ARMCOMPUTE_TAG=v20.05; fi +if [ -z "${LIBND4J_BUILD_MODE}" ]; then export LIBND4J_BUILD_MODE=Release; fi +if [ -z "${ANDROID_VERSION}" ]; then export ANDROID_VERSION=21; fi +if [ -z "${HAS_ARMCOMPUTE}" ]; then export HAS_ARMCOMPUTE=1; fi -BUILD_USING_MAVEN= -CURRENT_TARGET=arm32 -HAS_ARMCOMPUTE=1 -ARMCOMPUTE_DEBUG=0 -ARMCOMPUTE_TAG=v20.05 -LIBND4J_BUILD_MODE=Release -export ANDROID_VERSION=21 OTHER_ARGS=() while [[ $# -gt 0 ]] do @@ -88,18 +89,18 @@ while [ -h "$SOURCE" ]; do # resolve $SOURCE until the file is no longer a symli done BASE_DIR="$( cd -P "$( dirname "$SOURCE" )"
>/dev/null 2>&1 && pwd )" -CROSS_COMPILER_URL=${COMPILER_ARRS[$TARGET_INDEX]} -CROSS_COMPILER_DIR=${BASE_DIR}/compile_tools/cross_compiler_${COMPILER_DESTDIR[$TARGET_INDEX]} -COMPILER_DOWNLOAD_CMD=${COMPILER_DOWNLOAD_CMD_LIST[$TARGET_INDEX]} -DETECT=${DETECT_LIST[$TARGET_INDEX]} -LIBND4J_PLATFORM_EXT=${LIBND4J_PLATFORM_EXT_LIST[$TARGET_INDEX]} -BLAS_TARGET_NAME=${OPENBLAS_TARGETS[$TARGET_INDEX]} -ARMCOMPUTE_TARGET=${ARMCOMPUTE_TARGETS[$TARGET_INDEX]} -TARGET_OS=${OS_LIST[$TARGET_INDEX]} -LIBND4J_PLATFORM=${TARGET_OS}-${LIBND4J_PLATFORM_EXT} -PREFIX=${PREFIXES[$TARGET_INDEX]} +export CROSS_COMPILER_URL=${COMPILER_ARRS[$TARGET_INDEX]} +export CROSS_COMPILER_DIR=${BASE_DIR}/compile_tools/cross_compiler_${COMPILER_DESTDIR[$TARGET_INDEX]} +export COMPILER_DOWNLOAD_CMD=${COMPILER_DOWNLOAD_CMD_LIST[$TARGET_INDEX]} +export DETECT=${DETECT_LIST[$TARGET_INDEX]} +export LIBND4J_PLATFORM_EXT=${LIBND4J_PLATFORM_EXT_LIST[$TARGET_INDEX]} +export BLAS_TARGET_NAME=${OPENBLAS_TARGETS[$TARGET_INDEX]} +export ARMCOMPUTE_TARGET=${ARMCOMPUTE_TARGETS[$TARGET_INDEX]} +export TARGET_OS=${OS_LIST[$TARGET_INDEX]} +export LIBND4J_PLATFORM=${TARGET_OS}-${LIBND4J_PLATFORM_EXT} +export PREFIX=${PREFIXES[$TARGET_INDEX]} -CMAKE=cmake #/snap/bin/cmake +export CMAKE=cmake #/snap/bin/cmake mkdir -p ${BASE_DIR}/compile_tools/ SCONS_LOCAL_URL=http://prdownloads.sourceforge.net/scons/scons-local-3.1.1.tar.gz @@ -148,10 +149,10 @@ function download_extract_base { message "download" wget --quiet --show-progress -O ${3}_file ${2} fi - + message "extract $@" #extract - mkdir -p ${3} + mkdir -p ${3} if [ ${1} = "-unzip" ]; then command="unzip -qq ${3}_file -d ${3} " else @@ -178,12 +179,12 @@ function git_check { #$1 is url #$2 is dir #$3 is tag or branch if optional command= if [ -n "$3" ]; then - command="git clone --quiet --depth 1 --branch ${3} ${1} ${2}" - else + command="git clone --quiet --depth 1 --branch ${3} ${1} ${2}" + else command="git clone --quiet ${1} ${2}" fi message "$command" - $command + $command check_requirements "${2}" } @@ -195,7 +196,7 @@ function fix_pi_linker { fi rm -f ${1}/ld printf '#!/usr/bin/env bash\n'"${1}/ld.gold --long-plt \$*">${1}/ld - chmod +x ${1}/ld + chmod +x ${1}/ld } if [ ! -d ${CROSS_COMPILER_DIR}/folder ]; then @@ -206,46 +207,46 @@ if [ ! 
-d ${CROSS_COMPILER_DIR}/folder ]; then rename_top_folder ${CROSS_COMPILER_DIR} fi -CROSS_COMPILER_DIR=${CROSS_COMPILER_DIR}/folder +export CROSS_COMPILER_DIR=${CROSS_COMPILER_DIR}/folder if [ "${TARGET_OS}" = "android" ];then - ANDROID_TOOLCHAIN=${CROSS_COMPILER_DIR}/toolchains/llvm/prebuilt/linux-x86_64 - COMPILER_PREFIX="${ANDROID_TOOLCHAIN}/bin/${PREFIX}${ANDROID_VERSION}" - TOOLCHAIN_PREFIX="${ANDROID_TOOLCHAIN}/bin/${PREFIX}" + export ANDROID_TOOLCHAIN=${CROSS_COMPILER_DIR}/toolchains/llvm/prebuilt/linux-x86_64 + export COMPILER_PREFIX="${ANDROID_TOOLCHAIN}/bin/${PREFIX}${ANDROID_VERSION}" + export TOOLCHAIN_PREFIX="${ANDROID_TOOLCHAIN}/bin/${PREFIX}" if [ "$BLAS_TARGET_NAME" = "ARMV7" ];then BLAS_XTRA="ARM_SOFTFP_ABI=1 " COMPILER_PREFIX="${ANDROID_TOOLCHAIN}/bin/armv7a-linux-androideabi${ANDROID_VERSION}" fi - CC_EXE="clang" - CXX_EXE="clang++" - AR="${TOOLCHAIN_PREFIX}-ar" - RANLIB="${TOOLCHAIN_PREFIX}-ranlib" - BLAS_XTRA="CC=${COMPILER_PREFIX}-${CC_EXE} AR=${AR} RANLIB=${RANLIB} ${BLAS_XTRA}" + export CC_EXE="clang" + export CXX_EXE="clang++" + export AR="${TOOLCHAIN_PREFIX}-ar" + export RANLIB="${TOOLCHAIN_PREFIX}-ranlib" + export BLAS_XTRA="CC=${COMPILER_PREFIX}-${CC_EXE} AR=${AR} RANLIB=${RANLIB} ${BLAS_XTRA}" else - BINUTILS_BIN=${CROSS_COMPILER_DIR}/${PREFIX}/bin - COMPILER_PREFIX=${CROSS_COMPILER_DIR}/bin/${PREFIX} - TOOLCHAIN_PREFIX=${COMPILER_PREFIX} - SYS_ROOT=${CROSS_COMPILER_DIR}/${PREFIX}/libc + export BINUTILS_BIN=${CROSS_COMPILER_DIR}/${PREFIX}/bin + export COMPILER_PREFIX=${CROSS_COMPILER_DIR}/bin/${PREFIX} + export TOOLCHAIN_PREFIX=${COMPILER_PREFIX} + export SYS_ROOT=${CROSS_COMPILER_DIR}/${PREFIX}/libc #LD_LIBRARY_PATH=${CROSS_COMPILER_DIR}/lib:$LD_LIBRARY_PATH - CC_EXE="gcc" - CXX_EXE="g++" - RANLIB="${BINUTILS_BIN}/ranlib" + export CC_EXE="gcc" + export CXX_EXE="g++" + export RANLIB="${BINUTILS_BIN}/ranlib" export LD="${BINUTILS_BIN}/ld" - AR="${BINUTILS_BIN}/ar" - BLAS_XTRA="CC=${COMPILER_PREFIX}-${CC_EXE} AR=${AR} RANLIB=${RANLIB} CFLAGS=--sysroot=${SYS_ROOT} LDFLAGS=\"-L${SYS_ROOT}/../lib/ -lm\"" + export AR="${BINUTILS_BIN}/ar" + export BLAS_XTRA="CC=${COMPILER_PREFIX}-${CC_EXE} AR=${AR} RANLIB=${RANLIB} CFLAGS=--sysroot=${SYS_ROOT} LDFLAGS=\"-L${SYS_ROOT}/../lib/ -lm\"" fi check_requirements ${CC} if [ -z "${BUILD_USING_MAVEN}" ] ;then -#lets build OpenBlas +#lets build OpenBlas if [ ! -d "${OPENBLAS_DIR}" ]; then message "download OpenBLAS" git_check "${OPENBLAS_GIT_URL}" "${OPENBLAS_DIR}" "v0.3.10" fi if [ ! -f "${THIRD_PARTY}/lib/libopenblas.so" ]; then - message "build and install OpenBLAS" + message "build and install OpenBLAS" cd ${OPENBLAS_DIR} command="make TARGET=${BLAS_TARGET_NAME} HOSTCC=gcc NOFORTRAN=1 ${BLAS_XTRA} " @@ -271,9 +272,9 @@ if [ ! -d ${SCONS_LOCAL_DIR} ]; then fi check_requirements ${SCONS_LOCAL_DIR}/scons.py -if [ ! -d "${ARMCOMPUTE_DIR}" ]; then - message "download ArmCompute Source" - git_check ${ARMCOMPUTE_GIT_URL} "${ARMCOMPUTE_DIR}" "${ARMCOMPUTE_TAG}" +if [ ! 
-d "${ARMCOMPUTE_DIR}" ]; then + message "download ArmCompute Source" + git_check ${ARMCOMPUTE_GIT_URL} "${ARMCOMPUTE_DIR}" "${ARMCOMPUTE_TAG}" fi #build armcompute @@ -283,7 +284,7 @@ cd ${ARMCOMPUTE_DIR} command="CC=${CC_EXE} CXX=${CXX_EXE} python3 ${SCONS_LOCAL_DIR}/scons.py Werror=1 -j$(nproc) toolchain_prefix=${TOOLCHAIN_PREFIX}- compiler_prefix=${COMPILER_PREFIX}- debug=${ARMCOMPUTE_DEBUG} neon=1 opencl=0 extra_cxx_flags=-fPIC os=${TARGET_OS} build=cross_compile arch=${ARMCOMPUTE_TARGET} " message $command eval $command &>/dev/null -cd ${BASE_DIR} +cd ${BASE_DIR} fi check_requirements "${ARMCOMPUTE_DIR}/build/libarm_compute-static.a" "${ARMCOMPUTE_DIR}/build/libarm_compute_core-static.a" @@ -293,7 +294,7 @@ if [ "${TARGET_OS}" = "android" ];then export ANDROID_NDK=${CROSS_COMPILER_DIR} else export RPI_BIN=${CROSS_COMPILER_DIR}/bin/${PREFIX} - export JAVA_LIBRARY_PATH=${CROSS_COMPILER_DIR}/${PREFIX}/lib + export JAVA_LIBRARY_PATH=${CROSS_COMPILER_DIR}/${PREFIX}/lib fix_pi_linker ${BINUTILS_BIN} fi @@ -315,6 +316,13 @@ else message "cd $BASE_DIR/.. " cd $BASE_DIR/.. message "lets build jars" -DHELPER=" -Dlibnd4j.helper=armcompute " -mvn install -Dlibnd4j.platform=${LIBND4J_PLATFORM} -Djavacpp.platform=${LIBND4J_PLATFORM} -DprotocCommand=protoc -Djavacpp.platform.compiler=${COMPILER_PREFIX}-${CC_EXE} -Djava.library.path=${JAVA_LIBRARY_PATH} ${DHELPER} -Dmaven.test.skip=true -Dmaven.javadoc.skip=true +export DHELPER=" -Dlibnd4j.helper=armcompute " +if [ "${DEPLOY}" ]; then + echo "Deploying to maven" + mvn -P"${PUBLISH_TO}" deploy --batch-mode -Dlibnd4j.platform=${LIBND4J_PLATFORM} -Djavacpp.platform=${LIBND4J_PLATFORM} -DprotocCommand=protoc -Djavacpp.platform.compiler=${COMPILER_PREFIX}-${CC_EXE} -Djava.library.path=${JAVA_LIBRARY_PATH} ${DHELPER} -pl ":libnd4j,:nd4j-native" --also-make -DskipTests -Dmaven.test.skip=true -Dmaven.javadoc.skip=true + else + echo "Installing to local repo" + mvn install -Dlibnd4j.platform=${LIBND4J_PLATFORM} -Djavacpp.platform=${LIBND4J_PLATFORM} -DprotocCommand=protoc -Djavacpp.platform.compiler=${COMPILER_PREFIX}-${CC_EXE} -Djava.library.path=${JAVA_LIBRARY_PATH} ${DHELPER} -pl ":libnd4j" --also-make -DskipTests -Dmaven.test.skip=true -Dmaven.javadoc.skip=true +fi + fi diff --git a/libnd4j/tests_cpu/run_tests.sh b/libnd4j/tests_cpu/run_tests.sh index c06a99e0a..592e643fd 100755 --- a/libnd4j/tests_cpu/run_tests.sh +++ b/libnd4j/tests_cpu/run_tests.sh @@ -56,7 +56,7 @@ if [ -n "$BUILD_PATH" ]; then export PATH="$PATH:$BUILD_PATH" fi -../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests +../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests # Workaround to fix posix path conversion problem on Windows (http://mingw.org/wiki/Posix_path_conversion) [ -f "${GTEST_OUTPUT#*:}" ] && cp -a surefire-reports/ ../target && rm -rf surefire-reports/ diff --git a/nd4j/nd4j-serde/nd4j-aeron/pom.xml b/nd4j/nd4j-serde/nd4j-aeron/pom.xml index f868bbd89..236004a0c 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/pom.xml +++ b/nd4j/nd4j-serde/nd4j-aeron/pom.xml @@ -39,7 +39,7 @@ 1.8 1.8 1.5.4 - 1.4.0 + 1.32.0 @@ -93,7 +93,7 @@ - ${env.LD_LIBRARY_PATH}:${user.dir}:${libnd4jhome}/blasbuild/cpu/blas/ + ${env.LD_LIBRARY_PATH}${path.separator}${user.dir}${path.separator}${libnd4jhome}/blasbuild/cpu/blas/${path.separator}${libnd4jhome}/../nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/target/classes src/test/java @@ -119,7 +119,6 @@ For testing large zoo models, this may not be enough (so comment it out).
--> - -Ddtype=float -Dfile.encoding=UTF-8 -Xmx8g diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronUtil.java b/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronUtil.java index c373e712e..97e83a5aa 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronUtil.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/AeronUtil.java @@ -59,13 +59,15 @@ public class AeronUtil { ipcLength += 2; // System.setProperty("aeron.term.buffer.size",String.valueOf(ipcLength)); final MediaDriver.Context ctx = - new MediaDriver.Context().threadingMode(ThreadingMode.SHARED).dirsDeleteOnStart(true) - /* .ipcTermBufferLength(ipcLength) - .publicationTermBufferLength(ipcLength) - .maxTermBufferLength(ipcLength)*/ - .conductorIdleStrategy(new BusySpinIdleStrategy()) - .receiverIdleStrategy(new BusySpinIdleStrategy()) - .senderIdleStrategy(new BusySpinIdleStrategy()); + new MediaDriver.Context().threadingMode(ThreadingMode.SHARED) + .dirDeleteOnStart(true) + .dirDeleteOnShutdown(true) + /* .ipcTermBufferLength(ipcLength) + .publicationTermBufferLength(ipcLength) + .maxTermBufferLength(ipcLength)*/ + .conductorIdleStrategy(new BusySpinIdleStrategy()) + .receiverIdleStrategy(new BusySpinIdleStrategy()) + .senderIdleStrategy(new BusySpinIdleStrategy()); return ctx; } @@ -92,7 +94,7 @@ public class AeronUtil { * @return loop function */ public static Consumer subscriberLoop(final FragmentHandler fragmentHandler, final int limit, - final AtomicBoolean running, final AtomicBoolean launched) { + final AtomicBoolean running, final AtomicBoolean launched) { final IdleStrategy idleStrategy = new BusySpinIdleStrategy(); return subscriberLoop(fragmentHandler, limit, running, idleStrategy, launched); } @@ -109,7 +111,7 @@ public class AeronUtil { * @return loop function */ public static Consumer subscriberLoop(final FragmentHandler fragmentHandler, final int limit, - final AtomicBoolean running, final IdleStrategy idleStrategy, final AtomicBoolean launched) { + final AtomicBoolean running, final IdleStrategy idleStrategy, final AtomicBoolean launched) { return (subscription) -> { try { while (running.get()) { @@ -134,7 +136,7 @@ public class AeronUtil { buffer.getBytes(offset, data); System.out.println(String.format("Message to stream %d from session %d (%d@%d) <<%s>>", streamId, - header.sessionId(), length, offset, new String(data))); + header.sessionId(), length, offset, new String(data))); }; } @@ -149,7 +151,7 @@ public class AeronUtil { * @param cause of the error */ public static void printError(final String channel, final int streamId, final int sessionId, final String message, - final HeaderFlyweight cause) { + final HeaderFlyweight cause) { System.out.println(message); } @@ -162,9 +164,9 @@ public class AeronUtil { * @param totalBytes being reported */ public static void printRate(final double messagesPerSec, final double bytesPerSec, final long totalMessages, - final long totalBytes) { + final long totalBytes) { System.out.println(String.format("%.02g msgs/sec, %.02g bytes/sec, totals %d messages %d MB", messagesPerSec, - bytesPerSec, totalMessages, totalBytes / (1024 * 1024))); + bytesPerSec, totalMessages, totalBytes / (1024 * 1024))); } /** @@ -175,7 +177,7 @@ public class AeronUtil { public static void printAvailableImage(final Image image) { final Subscription subscription = image.subscription(); System.out.println(String.format("Available image on %s streamId=%d sessionId=%d from %s", - subscription.channel(), 
subscription.streamId(), image.sessionId(), image.sourceIdentity())); + subscription.channel(), subscription.streamId(), image.sessionId(), image.sourceIdentity())); } /** @@ -186,7 +188,7 @@ public class AeronUtil { public static void printUnavailableImage(final Image image) { final Subscription subscription = image.subscription(); System.out.println(String.format("Unavailable image on %s streamId=%d sessionId=%d", subscription.channel(), - subscription.streamId(), image.sessionId())); + subscription.streamId(), image.sessionId())); } private static final AtomicInteger conductorCount = new AtomicInteger(); diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/LowLatencyMediaDriver.java b/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/LowLatencyMediaDriver.java index e366a0c9d..d5d1da518 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/LowLatencyMediaDriver.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/main/java/org/nd4j/aeron/ipc/LowLatencyMediaDriver.java @@ -34,8 +34,7 @@ public class LowLatencyMediaDriver { @SuppressWarnings("checkstyle:UncommentedMain") public static void main(final String... args) { - MediaDriver.loadPropertiesFiles(args); - + MediaDriver.main(args); setProperty(DISABLE_BOUNDS_CHECKS_PROP_NAME, "true"); setProperty("aeron.mtu.length", "16384"); setProperty("aeron.socket.so_sndbuf", "2097152"); @@ -43,10 +42,11 @@ public class LowLatencyMediaDriver { setProperty("aeron.rcv.initial.window.length", "2097152"); final MediaDriver.Context ctx = - new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED).dirsDeleteOnStart(true) - .termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy()) - .receiverIdleStrategy(new BusySpinIdleStrategy()) - .senderIdleStrategy(new BusySpinIdleStrategy()); + new MediaDriver.Context().threadingMode(ThreadingMode.DEDICATED).dirDeleteOnStart(true) + .dirDeleteOnShutdown(true) + .termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy()) + .receiverIdleStrategy(new BusySpinIdleStrategy()) + .senderIdleStrategy(new BusySpinIdleStrategy()); try (MediaDriver ignored = MediaDriver.launch(ctx)) { new SigIntBarrier().await(); diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/AeronNDArraySerdeTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/AeronNDArraySerdeTest.java index eb0a400ff..b28937c62 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/AeronNDArraySerdeTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/AeronNDArraySerdeTest.java @@ -28,13 +28,14 @@ import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import javax.annotation.concurrent.NotThreadSafe; import java.io.BufferedOutputStream; import java.io.ByteArrayOutputStream; import java.io.DataOutputStream; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; - +@NotThreadSafe public class AeronNDArraySerdeTest extends BaseND4JTest { @Test diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java index 832b2c9c5..af4bb515c 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/LargeNdArrayIpcTest.java @@ -31,11 +31,13 @@ import org.nd4j.common.tests.BaseND4JTest; import 
org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import javax.annotation.concurrent.NotThreadSafe; import java.util.concurrent.atomic.AtomicBoolean; import static org.junit.Assert.assertFalse; @Slf4j +@NotThreadSafe public class LargeNdArrayIpcTest extends BaseND4JTest { private MediaDriver mediaDriver; private Aeron.Context ctx; @@ -73,9 +75,10 @@ public class LargeNdArrayIpcTest extends BaseND4JTest { int length = (int) 1e7; INDArray arr = Nd4j.ones(length); AeronNDArrayPublisher publisher; - ctx = new Aeron.Context().publicationConnectionTimeout(-1).availableImageHandler(AeronUtil::printAvailableImage) + ctx = new Aeron.Context() + .driverTimeoutMs(-1).availableImageHandler(AeronUtil::printAvailableImage) .unavailableImageHandler(AeronUtil::printUnavailableImage) - .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(10000) + .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(10000) .errorHandler(err -> err.printStackTrace()); final AtomicBoolean running = new AtomicBoolean(true); @@ -149,10 +152,10 @@ public class LargeNdArrayIpcTest extends BaseND4JTest { private Aeron.Context getContext() { if (ctx == null) - ctx = new Aeron.Context().publicationConnectionTimeout(-1) + ctx = new Aeron.Context().driverTimeoutMs(-1) .availableImageHandler(AeronUtil::printAvailableImage) .unavailableImageHandler(AeronUtil::printUnavailableImage) - .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(10000) + .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(10000) .errorHandler(err -> err.printStackTrace()); return ctx; } diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NDArrayMessageTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NDArrayMessageTest.java index ffc4e04e6..0a8b89277 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NDArrayMessageTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NDArrayMessageTest.java @@ -26,8 +26,11 @@ import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import javax.annotation.concurrent.NotThreadSafe; + import static org.junit.Assert.assertEquals; +@NotThreadSafe public class NDArrayMessageTest extends BaseND4JTest { @Test diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java index 253df0082..6dac31259 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/NdArrayIpcTest.java @@ -32,12 +32,14 @@ import org.nd4j.linalg.factory.Nd4j; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import javax.annotation.concurrent.NotThreadSafe; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import static org.junit.Assert.assertFalse; +@NotThreadSafe public class NdArrayIpcTest extends BaseND4JTest { private MediaDriver mediaDriver; private static Logger log = LoggerFactory.getLogger(NdArrayIpcTest.class); @@ -223,10 +225,10 @@ public class NdArrayIpcTest extends BaseND4JTest { private Aeron.Context getContext() { if (ctx == null) - ctx = new Aeron.Context().publicationConnectionTimeout(1000) + ctx = new Aeron.Context().driverTimeoutMs(1000) .availableImageHandler(image -> 
System.out.println(image)) .unavailableImageHandler(AeronUtil::printUnavailableImage) - .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000) + .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(1000) .errorHandler(e -> log.error(e.toString(), e)); return ctx; } diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/ChunkAccumulatorTests.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/ChunkAccumulatorTests.java index 62b724760..d3b89ef48 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/ChunkAccumulatorTests.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/ChunkAccumulatorTests.java @@ -25,8 +25,11 @@ import org.nd4j.common.tests.BaseND4JTest; import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.linalg.factory.Nd4j; +import javax.annotation.concurrent.NotThreadSafe; + import static org.junit.Assert.assertEquals; +@NotThreadSafe public class ChunkAccumulatorTests extends BaseND4JTest { @Test diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/NDArrayMessageChunkTests.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/NDArrayMessageChunkTests.java index 8df55f0bd..8ff2eae34 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/NDArrayMessageChunkTests.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/chunk/NDArrayMessageChunkTests.java @@ -27,11 +27,13 @@ import org.nd4j.aeron.ipc.NDArrayMessage; import org.nd4j.aeron.util.BufferUtil; import org.nd4j.linalg.factory.Nd4j; +import javax.annotation.concurrent.NotThreadSafe; import java.nio.ByteBuffer; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +@NotThreadSafe public class NDArrayMessageChunkTests extends BaseND4JTest { @Test diff --git a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/response/AeronNDArrayResponseTest.java b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/response/AeronNDArrayResponseTest.java index 7a663e690..1c4c46acd 100644 --- a/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/response/AeronNDArrayResponseTest.java +++ b/nd4j/nd4j-serde/nd4j-aeron/src/test/java/org/nd4j/aeron/ipc/response/AeronNDArrayResponseTest.java @@ -33,12 +33,14 @@ import org.nd4j.aeron.ipc.*; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; +import javax.annotation.concurrent.NotThreadSafe; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import static org.junit.Assert.assertEquals; @Slf4j +@NotThreadSafe public class AeronNDArrayResponseTest extends BaseND4JTest { private MediaDriver mediaDriver; @@ -51,7 +53,8 @@ public class AeronNDArrayResponseTest extends BaseND4JTest { public void before() { if(isIntegrationTests()) { final MediaDriver.Context ctx = - new MediaDriver.Context().threadingMode(ThreadingMode.SHARED).dirsDeleteOnStart(true) + new MediaDriver.Context().threadingMode(ThreadingMode.SHARED).dirDeleteOnShutdown(true) + .dirDeleteOnStart(true) .termBufferSparseFile(false).conductorIdleStrategy(new BusySpinIdleStrategy()) .receiverIdleStrategy(new BusySpinIdleStrategy()) .senderIdleStrategy(new BusySpinIdleStrategy()); @@ -69,10 +72,10 @@ public class AeronNDArrayResponseTest extends BaseND4JTest { int streamId = 10; int responderStreamId = 11; String host = "127.0.0.1"; - Aeron.Context ctx = new 
Aeron.Context().publicationConnectionTimeout(-1) + Aeron.Context ctx = new Aeron.Context().driverTimeoutMs(-1) .availableImageHandler(AeronUtil::printAvailableImage) .unavailableImageHandler(AeronUtil::printUnavailableImage) - .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveInterval(1000) + .aeronDirectoryName(mediaDriver.aeronDirectoryName()).keepAliveIntervalNs(1000) .errorHandler(e -> log.error(e.toString(), e)); int baseSubscriberPort = 40123 + new java.util.Random().nextInt(1000); diff --git a/pom.xml b/pom.xml index 3dce515f3..034af675e 100644 --- a/pom.xml +++ b/pom.xml @@ -92,18 +92,7 @@ - - - sonatype-nexus-releases - Nexus Release Repository - http://oss.sonatype.org/service/local/staging/deploy/maven2/ - - - sonatype-nexus-snapshots - Sonatype Nexus snapshot repository - https://oss.sonatype.org/content/repositories/snapshots - - + 1.7 @@ -184,7 +173,9 @@ ${javacpp.platform} - + + + 1.5.4 1.5.4 1.5.4 @@ -323,24 +314,23 @@ 1.0.0 2.2.0 1.4.30 + 1.3 - - - org.jetbrains.kotlin - kotlin-stdlib-jdk8 - ${kotlin.version} - - - org.jetbrains.kotlin - kotlin-test - ${kotlin.version} - test - - + + org.jetbrains.kotlin + kotlin-stdlib-jdk8 + ${kotlin.version} + + + org.jetbrains.kotlin + kotlin-test + ${kotlin.version} + test + io.netty netty-all @@ -421,6 +411,12 @@ netty-codec-dns ${netty.version} + + org.walkmod + junit4git + ${junit4git.version} + test + @@ -428,6 +424,58 @@ + + org.jetbrains.kotlin + kotlin-maven-plugin + ${kotlin.version} + + + -Xjsr305=strict + + + spring + jpa + + + + + org.jetbrains.kotlin + kotlin-maven-allopen + 1.4.30-M1 + + + org.jetbrains.kotlin + kotlin-maven-noarg + 1.4.30-M1 + + + + + compile + compile + + + ${project.basedir}/src/main/stubs + ${project.basedir}/src/main/kotlin + ${project.basedir}/src/main/java + ${project.basedir}/src/main/ops + + + + + test-compile + test-compile + + + ${project.basedir}/src/test/stubs + ${project.basedir}/src/test/kotlin + ${project.basedir}/src/test/java + ${project.basedir}/src/test/ops + + + + + org.apache.maven.plugins maven-compiler-plugin @@ -488,6 +536,7 @@ true + false ${project.basedir}/target/generated-sources/src/main/resources/org/eclipse/${project.groupId}-${project.artifactId}-git.properties @@ -659,6 +708,40 @@ + + github + + + github + GitHub Packages + https://maven.pkg.github.com/eclipse/deeplearning4j + + + + github + Github snapshots + https://maven.pkg.github.com/eclipse/deeplearning4j + + + + + + ossrh + + + sonatype-nexus-releases + Nexus Release Repository + http://oss.sonatype.org/service/local/staging/deploy/maven2/ + + + + sonatype-nexus-snapshots + Sonatype Nexus snapshot repository + https://oss.sonatype.org/content/repositories/snapshots + + + + skipTestCompileAndRun @@ -715,6 +798,10 @@ ${dl4j-test-resources.classifier} test + + org.walkmod + junit4git + @@ -769,6 +856,12 @@ org.apache.maven.plugins maven-gpg-plugin ${maven-gpg-plugin.version} + + + --pinentry-mode + loopback + + sign-artifacts @@ -1023,10 +1116,16 @@ true + + + listener + org.walkmod.junit4git.junit4.Junit4GitListener + + - + \ No newline at end of file
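
A side note on the `if [ -z "${VAR}" ]; then export VAR=...; fi` default blocks that nano_build.sh and pi_build.sh now share: bash parameter expansion gives the same unset-or-empty semantics in one token per variable. A minimal sketch, reusing the scripts' own variable names:

#!/usr/bin/env bash
# ':' is a no-op command; evaluating its argument assigns the default
# only when the variable is unset or empty, matching the -z checks above.
: "${CURRENT_TARGET:=arm32}"
: "${ARMCOMPUTE_DEBUG:=1}"
: "${ARMCOMPUTE_TAG:=v20.05}"
: "${LIBND4J_BUILD_MODE:=Release}"
: "${ANDROID_VERSION:=21}"
: "${HAS_ARMCOMPUTE:=1}"
export CURRENT_TARGET ARMCOMPUTE_DEBUG ARMCOMPUTE_TAG LIBND4J_BUILD_MODE ANDROID_VERSION HAS_ARMCOMPUTE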
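
For reference, download_extract_base takes the extract flag first, then the URL, then the destination directory (the comment inside the function has been corrected to match), and the download is cached next to the destination as ${dest}_file. A usage sketch with an illustrative URL and path:

# fetches to ${dest}_file on the first run, then unpacks into ${dest}/
dest="${BASE_DIR}/compile_tools/cross_compiler_arm64"
download_extract_xz "https://example.org/toolchain.tar.xz" "${dest}"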
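
rename_top_folder normalizes the archive layout by moving the single top-level directory to .../folder, and the script's own message notes this stands in for --strip-components=1. For the tar-based toolchains the same effect is available at extraction time; a sketch (it does not cover the NDK zip, which is presumably why renaming was chosen):

mkdir -p "${CROSS_COMPILER_DIR}/folder"
# drop the archive's single top-level directory while extracting
tar -xf "${CROSS_COMPILER_DIR}_file" --strip-components=1 --directory="${CROSS_COMPILER_DIR}/folder"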
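
fix_pi_linker renames the toolchain's ld to ld.original and installs a wrapper that forwards everything to ld.gold with --long-plt. The generated wrapper amounts to the two lines below; the toolchain path is illustrative, and "$@" is shown where the script writes \$* because it preserves quoted arguments:

#!/usr/bin/env bash
# wrapper installed over ${BINUTILS_BIN}/ld by fix_pi_linker (path illustrative)
exec /opt/cross/aarch64-linux-gnu/bin/ld.gold --long-plt "$@"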
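
With the DEPLOY and PUBLISH_TO switches introduced above, typical invocations might look as follows, assuming pi_build.sh parses -a/--arch and -m/--mvn the same way nano_build.sh does and that PUBLISH_TO names a distributionManagement profile from the root pom.xml (github or ossrh):

# install the arm64 artifacts into the local Maven repository
bash pi_build.sh --arch arm64 --mvn

# deploy through the github profile defined in the root pom.xml
DEPLOY=1 PUBLISH_TO=github bash pi_build.sh --arch arm64 --mvn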