diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml deleted file mode 100644 index c69e1abcb..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml +++ /dev/null @@ -1,64 +0,0 @@ - - - - - - 4.0.0 - - - org.datavec - datavec-spark-inference-parent - 1.0.0-SNAPSHOT - - - datavec-spark-inference-client - - datavec-spark-inference-client - - - - org.datavec - datavec-spark-inference-server_2.11 - 1.0.0-SNAPSHOT - test - - - org.datavec - datavec-spark-inference-model - ${project.parent.version} - - - com.mashape.unirest - unirest-java - - - - - - test-nd4j-native - - - test-nd4j-cuda-11.0 - - - diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/inference/client/DataVecTransformClient.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/inference/client/DataVecTransformClient.java deleted file mode 100644 index 8a346b096..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/inference/client/DataVecTransformClient.java +++ /dev/null @@ -1,292 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.inference.client;
-
-
-import com.mashape.unirest.http.ObjectMapper;
-import com.mashape.unirest.http.Unirest;
-import com.mashape.unirest.http.exceptions.UnirestException;
-import lombok.AllArgsConstructor;
-import lombok.extern.slf4j.Slf4j;
-import org.datavec.api.transform.TransformProcess;
-import org.datavec.image.transform.ImageTransformProcess;
-import org.datavec.spark.inference.model.model.*;
-import org.datavec.spark.inference.model.service.DataVecTransformService;
-import org.nd4j.shade.jackson.core.JsonProcessingException;
-
-import java.io.IOException;
-
-@AllArgsConstructor
-@Slf4j
-public class DataVecTransformClient implements DataVecTransformService {
-    private String url;
-
-    static {
-        // Only one time
-        Unirest.setObjectMapper(new ObjectMapper() {
-            private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
-                            new org.nd4j.shade.jackson.databind.ObjectMapper();
-
-            public <T> T readValue(String value, Class<T> valueType) {
-                try {
-                    return jacksonObjectMapper.readValue(value, valueType);
-                } catch (IOException e) {
-                    throw new RuntimeException(e);
-                }
-            }
-
-            public String writeValue(Object value) {
-                try {
-                    return jacksonObjectMapper.writeValueAsString(value);
-                } catch (JsonProcessingException e) {
-                    throw new RuntimeException(e);
-                }
-            }
-        });
-    }
-
-    /**
-     * @param transformProcess
-     */
-    @Override
-    public void setCSVTransformProcess(TransformProcess transformProcess) {
-        try {
-            String s = transformProcess.toJson();
-            Unirest.post(url + "/transformprocess").header("accept", "application/json")
-                            .header("Content-Type", "application/json").body(s).asJson();
-
-        } catch (UnirestException e) {
-            log.error("Error in setCSVTransformProcess()", e);
-        }
-    }
-
-    @Override
-    public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) {
-        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
-    }
-
-    /**
-     * @return
-     */
-    @Override
-    public TransformProcess getCSVTransformProcess() {
-        try {
-            String s = Unirest.get(url + "/transformprocess").header("accept", "application/json")
-                            .header("Content-Type", "application/json").asString().getBody();
-            return TransformProcess.fromJson(s);
-        } catch (UnirestException e) {
-            log.error("Error in getCSVTransformProcess()", e);
-        }
-
-        return null;
-    }
-
-    @Override
-    public ImageTransformProcess getImageTransformProcess() {
-        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
-    }
-
-    /**
-     * @param transform
-     * @return
-     */
-    @Override
-    public SingleCSVRecord transformIncremental(SingleCSVRecord transform) {
-        try {
-            SingleCSVRecord singleCsvRecord = Unirest.post(url + "/transformincremental")
-                            .header("accept", "application/json")
-                            .header("Content-Type", "application/json")
-                            .body(transform).asObject(SingleCSVRecord.class).getBody();
-            return singleCsvRecord;
-        } catch (UnirestException e) {
-            log.error("Error in transformIncremental(SingleCSVRecord)", e);
-        }
-        return null;
-    }
-
-
-    /**
-     * @param batchCSVRecord
-     * @return
-     */
-    @Override
-    public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
-        try {
-            SequenceBatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform").header("accept", "application/json")
-                            .header("Content-Type", "application/json")
-                            .header(SEQUENCE_OR_NOT_HEADER, "TRUE")
-                            .body(batchCSVRecord)
-
.asObject(SequenceBatchCSVRecord.class) - .getBody(); - return batchCSVRecord1; - } catch (UnirestException e) { - log.error("",e); - } - - return null; - } - /** - * @param batchCSVRecord - * @return - */ - @Override - public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) { - try { - BatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform").header("accept", "application/json") - .header("Content-Type", "application/json") - .header(SEQUENCE_OR_NOT_HEADER,"FALSE") - .body(batchCSVRecord) - .asObject(BatchCSVRecord.class) - .getBody(); - return batchCSVRecord1; - } catch (UnirestException e) { - log.error("Error in transform(BatchCSVRecord)", e); - } - - return null; - } - - /** - * @param batchCSVRecord - * @return - */ - @Override - public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) { - try { - Base64NDArrayBody batchArray1 = Unirest.post(url + "/transformarray").header("accept", "application/json") - .header("Content-Type", "application/json").body(batchCSVRecord) - .asObject(Base64NDArrayBody.class).getBody(); - return batchArray1; - } catch (UnirestException e) { - log.error("Error in transformArray(BatchCSVRecord)",e); - } - - return null; - } - - /** - * @param singleCsvRecord - * @return - */ - @Override - public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) { - try { - Base64NDArrayBody array = Unirest.post(url + "/transformincrementalarray") - .header("accept", "application/json").header("Content-Type", "application/json") - .body(singleCsvRecord).asObject(Base64NDArrayBody.class).getBody(); - return array; - } catch (UnirestException e) { - log.error("Error in transformArrayIncremental(SingleCSVRecord)",e); - } - - return null; - } - - @Override - public Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException { - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - @Override - public Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException { - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - /** - * @param singleCsvRecord - * @return - */ - @Override - public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) { - try { - Base64NDArrayBody array = Unirest.post(url + "/transformincrementalarray") - .header("accept", "application/json") - .header("Content-Type", "application/json") - .header(SEQUENCE_OR_NOT_HEADER,"true") - .body(singleCsvRecord).asObject(Base64NDArrayBody.class).getBody(); - return array; - } catch (UnirestException e) { - log.error("Error in transformSequenceArrayIncremental",e); - } - - return null; - } - - /** - * @param batchCSVRecord - * @return - */ - @Override - public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) { - try { - Base64NDArrayBody batchArray1 = Unirest.post(url + "/transformarray").header("accept", "application/json") - .header("Content-Type", "application/json") - .header(SEQUENCE_OR_NOT_HEADER,"true") - .body(batchCSVRecord) - .asObject(Base64NDArrayBody.class).getBody(); - return batchArray1; - } catch (UnirestException e) { - log.error("Error in transformSequenceArray",e); - } - - return null; - } - - /** - * @param batchCSVRecord - * @return - */ - @Override - public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) { - try { - SequenceBatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform") - .header("accept", 
"application/json") - .header("Content-Type", "application/json") - .header(SEQUENCE_OR_NOT_HEADER,"true") - .body(batchCSVRecord) - .asObject(SequenceBatchCSVRecord.class).getBody(); - return batchCSVRecord1; - } catch (UnirestException e) { - log.error("Error in transformSequence"); - } - - return null; - } - - /** - * @param transform - * @return - */ - @Override - public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) { - try { - SequenceBatchCSVRecord singleCsvRecord = Unirest.post(url + "/transformincremental") - .header("accept", "application/json") - .header("Content-Type", "application/json") - .header(SEQUENCE_OR_NOT_HEADER,"true") - .body(transform).asObject(SequenceBatchCSVRecord.class).getBody(); - return singleCsvRecord; - } catch (UnirestException e) { - log.error("Error in transformSequenceIncremental"); - } - return null; - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/AssertTestsExtendBaseClass.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/AssertTestsExtendBaseClass.java deleted file mode 100644 index de2970b27..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/AssertTestsExtendBaseClass.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ -package org.datavec.transform.client; - -import lombok.extern.slf4j.Slf4j; -import org.nd4j.common.tests.AbstractAssertTestsClass; -import org.nd4j.common.tests.BaseND4JTest; -import java.util.*; - -@Slf4j -public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { - - @Override - protected Set> getExclusions() { - //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) - return new HashSet<>(); - } - - @Override - protected String getPackageName() { - return "org.datavec.transform.client"; - } - - @Override - protected Class getBaseClass() { - return BaseND4JTest.class; - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/DataVecTransformClientTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/DataVecTransformClientTest.java deleted file mode 100644 index 6619ec443..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/DataVecTransformClientTest.java +++ /dev/null @@ -1,139 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.transform.client; - -import org.apache.commons.io.FileUtils; -import org.datavec.api.transform.TransformProcess; -import org.datavec.api.transform.schema.Schema; -import org.datavec.spark.inference.server.CSVSparkTransformServer; -import org.datavec.spark.inference.client.DataVecTransformClient; -import org.datavec.spark.inference.model.model.Base64NDArrayBody; -import org.datavec.spark.inference.model.model.BatchCSVRecord; -import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord; -import org.datavec.spark.inference.model.model.SingleCSVRecord; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.serde.base64.Nd4jBase64; - -import java.io.File; -import java.io.IOException; -import java.net.ServerSocket; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.UUID; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assume.assumeNotNull; - -public class DataVecTransformClientTest { - private static CSVSparkTransformServer server; - private static int port = getAvailablePort(); - private static DataVecTransformClient client; - private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); - private static TransformProcess transformProcess = - new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build(); - private static File fileSave = new File(UUID.randomUUID().toString() + ".json"); - - @BeforeClass - public static void beforeClass() throws Exception { - FileUtils.write(fileSave, transformProcess.toJson()); - fileSave.deleteOnExit(); - server = new CSVSparkTransformServer(); - server.runMain(new String[] {"-dp", String.valueOf(port)}); - - client = new DataVecTransformClient("http://localhost:" + port); - client.setCSVTransformProcess(transformProcess); - } - - @AfterClass - public static void afterClass() throws Exception { - server.stop(); - } - - - @Test - public void testSequenceClient() { - SequenceBatchCSVRecord sequenceBatchCSVRecord = new SequenceBatchCSVRecord(); - SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"}); - - BatchCSVRecord batchCSVRecord = new BatchCSVRecord(Arrays.asList(singleCsvRecord, singleCsvRecord)); - List batchCSVRecordList = new ArrayList<>(); - for(int i = 0; i < 5; i++) { - batchCSVRecordList.add(batchCSVRecord); - } - - sequenceBatchCSVRecord.add(batchCSVRecordList); - - SequenceBatchCSVRecord sequenceBatchCSVRecord1 = client.transformSequence(sequenceBatchCSVRecord); - assumeNotNull(sequenceBatchCSVRecord1); - - Base64NDArrayBody array = client.transformSequenceArray(sequenceBatchCSVRecord); - assumeNotNull(array); - - Base64NDArrayBody incrementalBody = client.transformSequenceArrayIncremental(batchCSVRecord); - assumeNotNull(incrementalBody); - - Base64NDArrayBody incrementalSequenceBody = client.transformSequenceArrayIncremental(batchCSVRecord); - assumeNotNull(incrementalSequenceBody); - } - - @Test - public void testRecord() throws Exception { - SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"}); - SingleCSVRecord transformed = client.transformIncremental(singleCsvRecord); - assertEquals(singleCsvRecord.getValues().size(), transformed.getValues().size()); - Base64NDArrayBody body = 
client.transformArrayIncremental(singleCsvRecord); - INDArray arr = Nd4jBase64.fromBase64(body.getNdarray()); - assumeNotNull(arr); - } - - @Test - public void testBatchRecord() throws Exception { - SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"}); - - BatchCSVRecord batchCSVRecord = new BatchCSVRecord(Arrays.asList(singleCsvRecord, singleCsvRecord)); - BatchCSVRecord batchCSVRecord1 = client.transform(batchCSVRecord); - assertEquals(batchCSVRecord.getRecords().size(), batchCSVRecord1.getRecords().size()); - - Base64NDArrayBody body = client.transformArray(batchCSVRecord); - INDArray arr = Nd4jBase64.fromBase64(body.getNdarray()); - assumeNotNull(arr); - } - - - - public static int getAvailablePort() { - try { - ServerSocket socket = new ServerSocket(0); - try { - return socket.getLocalPort(); - } finally { - socket.close(); - } - } catch (IOException e) { - throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e); - } - } - -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/resources/application.conf b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/resources/application.conf deleted file mode 100644 index dbac92d83..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/resources/application.conf +++ /dev/null @@ -1,6 +0,0 @@ -play.modules.enabled += com.lightbend.lagom.discovery.zookeeper.ZooKeeperServiceLocatorModule -play.modules.enabled += io.skymind.skil.service.PredictionModule -play.crypto.secret = as8dufasdfuasdfjkasdkfalksjfk -play.server.pidfile.path=/tmp/RUNNING_PID - -play.server.http.port = 9600 diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml deleted file mode 100644 index fe9ca985a..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml +++ /dev/null @@ -1,63 +0,0 @@ - - - - - - 4.0.0 - - - org.datavec - datavec-spark-inference-parent - 1.0.0-SNAPSHOT - - - datavec-spark-inference-model - - datavec-spark-inference-model - - - - org.datavec - datavec-api - ${datavec.version} - - - org.datavec - datavec-data-image - - - org.datavec - datavec-local - ${project.version} - - - - - - test-nd4j-native - - - test-nd4j-cuda-11.0 - - - diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/CSVSparkTransform.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/CSVSparkTransform.java deleted file mode 100644 index e081708e0..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/CSVSparkTransform.java +++ /dev/null @@ -1,286 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. 
- * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.model; - -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.apache.arrow.memory.BufferAllocator; -import org.apache.arrow.memory.RootAllocator; -import org.apache.arrow.vector.FieldVector; -import org.datavec.api.transform.TransformProcess; -import org.datavec.api.util.ndarray.RecordConverter; -import org.datavec.api.writable.Writable; -import org.datavec.arrow.ArrowConverter; -import org.datavec.arrow.recordreader.ArrowWritableRecordBatch; -import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch; -import org.datavec.local.transforms.LocalTransformExecutor; -import org.datavec.spark.inference.model.model.Base64NDArrayBody; -import org.datavec.spark.inference.model.model.BatchCSVRecord; -import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord; -import org.datavec.spark.inference.model.model.SingleCSVRecord; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.serde.base64.Nd4jBase64; - -import java.io.IOException; -import java.util.Arrays; -import java.util.List; - -import static org.datavec.arrow.ArrowConverter.*; -import static org.datavec.local.transforms.LocalTransformExecutor.execute; -import static org.datavec.local.transforms.LocalTransformExecutor.executeToSequence; - -@AllArgsConstructor -@Slf4j -public class CSVSparkTransform { - @Getter - private TransformProcess transformProcess; - private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE); - - /** - * Convert a raw record via - * the {@link TransformProcess} - * to a base 64ed ndarray - * @param batch the record to convert - * @return teh base 64ed ndarray - * @throws IOException - */ - public Base64NDArrayBody toArray(BatchCSVRecord batch) throws IOException { - List> converted = execute(toArrowWritables(toArrowColumnsString( - bufferAllocator,transformProcess.getInitialSchema(), - batch.getRecordsAsString()), - transformProcess.getInitialSchema()),transformProcess); - - ArrowWritableRecordBatch arrowRecordBatch = (ArrowWritableRecordBatch) converted; - INDArray convert = ArrowConverter.toArray(arrowRecordBatch); - return new Base64NDArrayBody(Nd4jBase64.base64String(convert)); - } - - /** - * Convert a raw record via - * the {@link TransformProcess} - * to a base 64ed ndarray - * @param record the record to convert - * @return the base 64ed ndarray - * @throws IOException - */ - public Base64NDArrayBody toArray(SingleCSVRecord record) throws IOException { - List record2 = toArrowWritablesSingle( - toArrowColumnsStringSingle(bufferAllocator, - transformProcess.getInitialSchema(),record.getValues()), - transformProcess.getInitialSchema()); - List finalRecord = execute(Arrays.asList(record2),transformProcess).get(0); - INDArray convert = RecordConverter.toArray(DataType.DOUBLE, finalRecord); - return new Base64NDArrayBody(Nd4jBase64.base64String(convert)); - } - - /** - * Runs the transform process - * @param batch the record to transform - * 
@return the transformed record - */ - public BatchCSVRecord transform(BatchCSVRecord batch) { - BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); - List> converted = execute(toArrowWritables(toArrowColumnsString( - bufferAllocator,transformProcess.getInitialSchema(), - batch.getRecordsAsString()), - transformProcess.getInitialSchema()),transformProcess); - int numCols = converted.get(0).size(); - for (int row = 0; row < converted.size(); row++) { - String[] values = new String[numCols]; - for (int i = 0; i < values.length; i++) - values[i] = converted.get(row).get(i).toString(); - batchCSVRecord.add(new SingleCSVRecord(values)); - } - - return batchCSVRecord; - - } - - /** - * Runs the transform process - * @param record the record to transform - * @return the transformed record - */ - public SingleCSVRecord transform(SingleCSVRecord record) { - List record2 = toArrowWritablesSingle( - toArrowColumnsStringSingle(bufferAllocator, - transformProcess.getInitialSchema(),record.getValues()), - transformProcess.getInitialSchema()); - List finalRecord = execute(Arrays.asList(record2),transformProcess).get(0); - String[] values = new String[finalRecord.size()]; - for (int i = 0; i < values.length; i++) - values[i] = finalRecord.get(i).toString(); - return new SingleCSVRecord(values); - - } - - /** - * - * @param transform - * @return - */ - public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) { - /** - * Sequence schema? - */ - List>> converted = executeToSequence( - toArrowWritables(toArrowColumnsStringTimeSeries( - bufferAllocator, transformProcess.getInitialSchema(), - Arrays.asList(transform.getRecordsAsString())), - transformProcess.getInitialSchema()), transformProcess); - - SequenceBatchCSVRecord batchCSVRecord = new SequenceBatchCSVRecord(); - for (int i = 0; i < converted.size(); i++) { - BatchCSVRecord batchCSVRecord1 = BatchCSVRecord.fromWritables(converted.get(i)); - batchCSVRecord.add(Arrays.asList(batchCSVRecord1)); - } - - return batchCSVRecord; - } - - /** - * - * @param batchCSVRecordSequence - * @return - */ - public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecordSequence) { - List>> recordsAsString = batchCSVRecordSequence.getRecordsAsString(); - boolean allSameLength = true; - Integer length = null; - for(List> record : recordsAsString) { - if(length == null) { - length = record.size(); - } - else if(record.size() != length) { - allSameLength = false; - } - } - - if(allSameLength) { - List fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), recordsAsString); - ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors, - transformProcess.getInitialSchema(), - recordsAsString.get(0).get(0).size()); - val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch,transformProcess); - return SequenceBatchCSVRecord.fromWritables(transformed); - } - - else { - val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecordSequence.getRecordsAsString(),transformProcess.getInitialSchema()),transformProcess); - return SequenceBatchCSVRecord.fromWritables(transformed); - - } - } - - /** - * TODO: optimize - * @param batchCSVRecordSequence - * @return - */ - public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecordSequence) { - List>> strings = 
batchCSVRecordSequence.getRecordsAsString(); - boolean allSameLength = true; - Integer length = null; - for(List> record : strings) { - if(length == null) { - length = record.size(); - } - else if(record.size() != length) { - allSameLength = false; - } - } - - if(allSameLength) { - List fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), strings); - ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors,transformProcess.getInitialSchema(),strings.get(0).get(0).size()); - val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch,transformProcess); - INDArray arr = RecordConverter.toTensor(transformed).reshape(strings.size(),strings.get(0).get(0).size(),strings.get(0).size()); - try { - return new Base64NDArrayBody(Nd4jBase64.base64String(arr)); - } catch (IOException e) { - throw new IllegalStateException(e); - } - } - - else { - val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecordSequence.getRecordsAsString(),transformProcess.getInitialSchema()),transformProcess); - INDArray arr = RecordConverter.toTensor(transformed).reshape(strings.size(),strings.get(0).get(0).size(),strings.get(0).size()); - try { - return new Base64NDArrayBody(Nd4jBase64.base64String(arr)); - } catch (IOException e) { - throw new IllegalStateException(e); - } - } - - } - - /** - * - * @param singleCsvRecord - * @return - */ - public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) { - List>> converted = executeToSequence(toArrowWritables(toArrowColumnsString( - bufferAllocator,transformProcess.getInitialSchema(), - singleCsvRecord.getRecordsAsString()), - transformProcess.getInitialSchema()),transformProcess); - ArrowWritableRecordTimeSeriesBatch arrowWritableRecordBatch = (ArrowWritableRecordTimeSeriesBatch) converted; - INDArray arr = RecordConverter.toTensor(arrowWritableRecordBatch); - try { - return new Base64NDArrayBody(Nd4jBase64.base64String(arr)); - } catch (IOException e) { - log.error("",e); - } - - return null; - } - - public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) { - List>> strings = batchCSVRecord.getRecordsAsString(); - boolean allSameLength = true; - Integer length = null; - for(List> record : strings) { - if(length == null) { - length = record.size(); - } - else if(record.size() != length) { - allSameLength = false; - } - } - - if(allSameLength) { - List fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), strings); - ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors,transformProcess.getInitialSchema(),strings.get(0).get(0).size()); - val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch,transformProcess); - return SequenceBatchCSVRecord.fromWritables(transformed); - } - - else { - val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecord.getRecordsAsString(),transformProcess.getInitialSchema()),transformProcess); - return SequenceBatchCSVRecord.fromWritables(transformed); - - } - - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/ImageSparkTransform.java 
b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/ImageSparkTransform.java deleted file mode 100644 index a004c439b..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/ImageSparkTransform.java +++ /dev/null @@ -1,64 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.model; - -import lombok.AllArgsConstructor; -import lombok.Getter; -import org.datavec.image.data.ImageWritable; -import org.datavec.image.transform.ImageTransformProcess; -import org.datavec.spark.inference.model.model.Base64NDArrayBody; -import org.datavec.spark.inference.model.model.BatchImageRecord; -import org.datavec.spark.inference.model.model.SingleImageRecord; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.serde.base64.Nd4jBase64; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -@AllArgsConstructor -public class ImageSparkTransform { - @Getter - private ImageTransformProcess imageTransformProcess; - - public Base64NDArrayBody toArray(SingleImageRecord record) throws IOException { - ImageWritable record2 = imageTransformProcess.transformFileUriToInput(record.getUri()); - INDArray finalRecord = imageTransformProcess.executeArray(record2); - - return new Base64NDArrayBody(Nd4jBase64.base64String(finalRecord)); - } - - public Base64NDArrayBody toArray(BatchImageRecord batch) throws IOException { - List records = new ArrayList<>(); - - for (SingleImageRecord imgRecord : batch.getRecords()) { - ImageWritable record2 = imageTransformProcess.transformFileUriToInput(imgRecord.getUri()); - INDArray finalRecord = imageTransformProcess.executeArray(record2); - records.add(finalRecord); - } - - INDArray array = Nd4j.concat(0, records.toArray(new INDArray[records.size()])); - - return new Base64NDArrayBody(Nd4jBase64.base64String(array)); - } - -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/Base64NDArrayBody.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/Base64NDArrayBody.java deleted file mode 100644 index 0d6c680ad..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/Base64NDArrayBody.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying 
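// ---------------------------------------------------------------------------
// A local-usage sketch for the ImageSparkTransform class above (together with
// the SingleImageRecord and BatchImageRecord models defined later in this
// diff). Illustrative only: the builder options, seed, and file paths are
// assumptions, not repository code.
//
// ImageTransformProcess itp = new ImageTransformProcess.Builder()
//         .seed(12345)
//         .scaleImageTransform(10)
//         .build();
// ImageSparkTransform imgTransform = new ImageSparkTransform(itp);
// Base64NDArrayBody single = imgTransform.toArray(
//         new SingleImageRecord(new File("/tmp/example.png").toURI()));
//
// BatchImageRecord batch = new BatchImageRecord();
// batch.add(new File("/tmp/a.png").toURI());
// batch.add(new File("/tmp/b.png").toURI());
// Base64NDArrayBody stacked = imgTransform.toArray(batch);  // Nd4j.concat on dim 0
// ---------------------------------------------------------------------------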
materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.model.model; - -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; - -@Data -@AllArgsConstructor -@NoArgsConstructor -public class Base64NDArrayBody { - private String ndarray; -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/BatchCSVRecord.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/BatchCSVRecord.java deleted file mode 100644 index 82ecedc51..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/BatchCSVRecord.java +++ /dev/null @@ -1,104 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.model.model; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; -import org.datavec.api.writable.Writable; -import org.nd4j.linalg.dataset.DataSet; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; - -@Data -@AllArgsConstructor -@Builder -@NoArgsConstructor -public class BatchCSVRecord implements Serializable { - private List records; - - - /** - * Get the records as a list of strings - * (basically the underlying values for - * {@link SingleCSVRecord}) - * @return - */ - public List> getRecordsAsString() { - if(records == null) - records = new ArrayList<>(); - List> ret = new ArrayList<>(); - for(SingleCSVRecord csvRecord : records) { - ret.add(csvRecord.getValues()); - } - return ret; - } - - - /** - * Create a batch csv record - * from a list of writables. 
- * @param batch
- * @return
- */
-    public static BatchCSVRecord fromWritables(List<List<Writable>> batch) {
-        List<SingleCSVRecord> records = new ArrayList<>(batch.size());
-        for(List<Writable> list : batch) {
-            List<String> add = new ArrayList<>(list.size());
-            for(Writable writable : list) {
-                add.add(writable.toString());
-            }
-            records.add(new SingleCSVRecord(add));
-        }
-
-        return BatchCSVRecord.builder().records(records).build();
-    }
-
-
-    /**
-     * Add a record
-     * @param record
-     */
-    public void add(SingleCSVRecord record) {
-        if (records == null)
-            records = new ArrayList<>();
-        records.add(record);
-    }
-
-
-    /**
-     * Return a batch record based on a dataset
-     * @param dataSet the dataset to get the batch record for
-     * @return the batch record
-     */
-    public static BatchCSVRecord fromDataSet(DataSet dataSet) {
-        BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
-        for (int i = 0; i < dataSet.numExamples(); i++) {
-            batchCSVRecord.add(SingleCSVRecord.fromRow(dataSet.get(i)));
-        }
-
-        return batchCSVRecord;
-    }
-
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/BatchImageRecord.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/BatchImageRecord.java
deleted file mode 100644
index ff101c659..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/BatchImageRecord.java
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
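// ---------------------------------------------------------------------------
// A small sketch of how the BatchCSVRecord container ending above is assembled
// and unpacked (the values are made up):
//
// BatchCSVRecord batch = new BatchCSVRecord();
// batch.add(new SingleCSVRecord("1.0", "2.0"));
// batch.add(new SingleCSVRecord("3.0", "4.0"));
// List<List<String>> raw = batch.getRecordsAsString();
// // raw -> [["1.0", "2.0"], ["3.0", "4.0"]]
// ---------------------------------------------------------------------------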
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.model.model; - -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.net.URI; -import java.util.ArrayList; -import java.util.List; - -@Data -@AllArgsConstructor -@NoArgsConstructor -public class BatchImageRecord { - private List records; - - /** - * Add a record - * @param record - */ - public void add(SingleImageRecord record) { - if (records == null) - records = new ArrayList<>(); - records.add(record); - } - - public void add(URI uri) { - this.add(new SingleImageRecord(uri)); - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SequenceBatchCSVRecord.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SequenceBatchCSVRecord.java deleted file mode 100644 index eed4fac59..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SequenceBatchCSVRecord.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.inference.model.model;
-
-import lombok.AllArgsConstructor;
-import lombok.Builder;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-import org.datavec.api.writable.Writable;
-import org.nd4j.linalg.dataset.DataSet;
-import org.nd4j.linalg.dataset.MultiDataSet;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-
-@Data
-@AllArgsConstructor
-@Builder
-@NoArgsConstructor
-public class SequenceBatchCSVRecord implements Serializable {
-    private List<List<BatchCSVRecord>> records;
-
-    /**
-     * Add a record
-     * @param record
-     */
-    public void add(List<BatchCSVRecord> record) {
-        if (records == null)
-            records = new ArrayList<>();
-        records.add(record);
-    }
-
-    /**
-     * Get the records as a list of strings directly
-     * (this basically "unpacks" the objects)
-     * @return
-     */
-    public List<List<List<String>>> getRecordsAsString() {
-        if(records == null)
-            return Collections.emptyList();
-        List<List<List<String>>> ret = new ArrayList<>(records.size());
-        for(List<BatchCSVRecord> record : records) {
-            List<List<String>> add = new ArrayList<>();
-            for(BatchCSVRecord batchCSVRecord : record) {
-                for (SingleCSVRecord singleCSVRecord : batchCSVRecord.getRecords()) {
-                    add.add(singleCSVRecord.getValues());
-                }
-            }
-
-            ret.add(add);
-        }
-
-        return ret;
-    }
-
-    /**
-     * Convert a writables time series to a sequence batch
-     * @param input
-     * @return
-     */
-    public static SequenceBatchCSVRecord fromWritables(List<List<List<Writable>>> input) {
-        SequenceBatchCSVRecord ret = new SequenceBatchCSVRecord();
-        for(int i = 0; i < input.size(); i++) {
-            ret.add(Arrays.asList(BatchCSVRecord.fromWritables(input.get(i))));
-        }
-
-        return ret;
-    }
-
-
-    /**
-     * Return a batch record based on a dataset
-     * @param dataSet the dataset to get the batch record for
-     * @return the batch record
-     */
-    public static SequenceBatchCSVRecord fromDataSet(MultiDataSet dataSet) {
-        SequenceBatchCSVRecord batchCSVRecord = new SequenceBatchCSVRecord();
-        for (int i = 0; i < dataSet.numFeatureArrays(); i++) {
-            batchCSVRecord.add(Arrays.asList(BatchCSVRecord.fromDataSet(
-                            new DataSet(dataSet.getFeatures(i), dataSet.getLabels(i)))));
-        }
-
-        return batchCSVRecord;
-    }
-
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SingleCSVRecord.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SingleCSVRecord.java
deleted file mode 100644
index 575a91918..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SingleCSVRecord.java
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
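// ---------------------------------------------------------------------------
// A worked shape example for the SequenceBatchCSVRecord class above (values
// made up). Each outer entry is one sequence; getRecordsAsString() flattens
// every BatchCSVRecord in a sequence into one list of rows:
//
// SequenceBatchCSVRecord seq = new SequenceBatchCSVRecord();
// BatchCSVRecord step = new BatchCSVRecord();
// step.add(new SingleCSVRecord("0", "1"));
// seq.add(Arrays.asList(step, step));  // one sequence containing two steps
// List<List<List<String>>> s = seq.getRecordsAsString();
// // s.size() == 1 sequence; s.get(0).size() == 2 rows; 2 columns per row
// ---------------------------------------------------------------------------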
See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.model.model; - -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; -import org.nd4j.linalg.dataset.DataSet; - -import java.io.Serializable; -import java.util.Arrays; -import java.util.List; - -@Data -@AllArgsConstructor -@NoArgsConstructor -public class SingleCSVRecord implements Serializable { - private List values; - - /** - * Create from an array of values uses list internally) - * @param values - */ - public SingleCSVRecord(String...values) { - this.values = Arrays.asList(values); - } - - /** - * Instantiate a csv record from a vector - * given either an input dataset and a - * one hot matrix, the index will be appended to - * the end of the record, or for regression - * it will append all values in the labels - * @param row the input vectors - * @return the record from this {@link DataSet} - */ - public static SingleCSVRecord fromRow(DataSet row) { - if (!row.getFeatures().isVector() && !row.getFeatures().isScalar()) - throw new IllegalArgumentException("Passed in dataset must represent a scalar or vector"); - if (!row.getLabels().isVector() && !row.getLabels().isScalar()) - throw new IllegalArgumentException("Passed in dataset labels must be a scalar or vector"); - //classification - SingleCSVRecord record; - int idx = 0; - if (row.getLabels().sumNumber().doubleValue() == 1.0) { - String[] values = new String[row.getFeatures().columns() + 1]; - for (int i = 0; i < row.getFeatures().length(); i++) { - values[idx++] = String.valueOf(row.getFeatures().getDouble(i)); - } - int maxIdx = 0; - for (int i = 0; i < row.getLabels().length(); i++) { - if (row.getLabels().getDouble(maxIdx) < row.getLabels().getDouble(i)) { - maxIdx = i; - } - } - - values[idx++] = String.valueOf(maxIdx); - record = new SingleCSVRecord(values); - } - //regression (any number of values) - else { - String[] values = new String[row.getFeatures().columns() + row.getLabels().columns()]; - for (int i = 0; i < row.getFeatures().length(); i++) { - values[idx++] = String.valueOf(row.getFeatures().getDouble(i)); - } - for (int i = 0; i < row.getLabels().length(); i++) { - values[idx++] = String.valueOf(row.getLabels().getDouble(i)); - } - - - record = new SingleCSVRecord(values); - - } - return record; - } - -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SingleImageRecord.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SingleImageRecord.java deleted file mode 100644 index 9fe3df042..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SingleImageRecord.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. 
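// ---------------------------------------------------------------------------
// A worked example for SingleCSVRecord.fromRow(...) above (values made up).
// With a one-hot label row the label sum is 1.0, so the classification branch
// runs and the argmax index is appended after the features:
//
// DataSet ds = new DataSet(
//         Nd4j.create(new double[][] {{0.25, 0.75}}),  // features (1 x 2)
//         Nd4j.create(new double[][] {{0.0, 1.0}}));   // one-hot labels (1 x 2)
// SingleCSVRecord rec = SingleCSVRecord.fromRow(ds);
// // rec.getValues() -> ["0.25", "0.75", "1"]  (features, then class index 1)
// ---------------------------------------------------------------------------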
- * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.model.model; - -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.net.URI; - -@Data -@AllArgsConstructor -@NoArgsConstructor -public class SingleImageRecord { - private URI uri; -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/service/DataVecTransformService.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/service/DataVecTransformService.java deleted file mode 100644 index c23dd562c..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/service/DataVecTransformService.java +++ /dev/null @@ -1,131 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.model.service; - -import org.datavec.api.transform.TransformProcess; -import org.datavec.image.transform.ImageTransformProcess; -import org.datavec.spark.inference.model.model.*; - -import java.io.IOException; - -public interface DataVecTransformService { - - String SEQUENCE_OR_NOT_HEADER = "Sequence"; - - - /** - * - * @param transformProcess - */ - void setCSVTransformProcess(TransformProcess transformProcess); - - /** - * - * @param imageTransformProcess - */ - void setImageTransformProcess(ImageTransformProcess imageTransformProcess); - - /** - * - * @return - */ - TransformProcess getCSVTransformProcess(); - - /** - * - * @return - */ - ImageTransformProcess getImageTransformProcess(); - - /** - * - * @param singleCsvRecord - * @return - */ - SingleCSVRecord transformIncremental(SingleCSVRecord singleCsvRecord); - - SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord); - - /** - * - * @param batchCSVRecord - * @return - */ - BatchCSVRecord transform(BatchCSVRecord batchCSVRecord); - - /** - * - * @param batchCSVRecord - * @return - */ - Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord); - - /** - * - * @param singleCsvRecord - * @return - */ - Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord); - - /** - * - * @param singleImageRecord - * @return - * @throws IOException - */ - Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException; - - /** - * - * @param batchImageRecord - * @return - * @throws IOException - */ - Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException; - - /** - * - * @param singleCsvRecord - * @return - */ - Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord); - - /** - * - * @param batchCSVRecord - * @return - */ - Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord); - - /** - * - * @param batchCSVRecord - * @return - */ - SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord); - - /** - * - * @param transform - * @return - */ - SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform); -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java deleted file mode 100644 index ab76b206e..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
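// ---------------------------------------------------------------------------
// Note on the SEQUENCE_OR_NOT_HEADER constant above: "Sequence" is an HTTP
// header name. The client deleted earlier in this diff sends it with value
// "TRUE"/"FALSE" (or "true") on POST /transform, which lets a server expose a
// single endpoint and dispatch to either transform(SequenceBatchCSVRecord) or
// transform(BatchCSVRecord). A sketch mirroring the deleted client code:
//
// Unirest.post(url + "/transform")
//        .header(DataVecTransformService.SEQUENCE_OR_NOT_HEADER, "FALSE")
//        .body(batchCSVRecord)
//        .asObject(BatchCSVRecord.class);
// ---------------------------------------------------------------------------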
See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ -package org.datavec.spark.transform; - -import lombok.extern.slf4j.Slf4j; -import org.nd4j.common.tests.AbstractAssertTestsClass; -import org.nd4j.common.tests.BaseND4JTest; - -import java.util.*; - -@Slf4j -public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { - - @Override - protected Set> getExclusions() { - //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) - return new HashSet<>(); - } - - @Override - protected String getPackageName() { - return "org.datavec.spark.transform"; - } - - @Override - protected Class getBaseClass() { - return BaseND4JTest.class; - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/BatchCSVRecordTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/BatchCSVRecordTest.java deleted file mode 100644 index a5ce6c474..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/BatchCSVRecordTest.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.transform; - -import org.datavec.spark.inference.model.model.BatchCSVRecord; -import org.junit.Test; -import org.nd4j.linalg.dataset.DataSet; -import org.nd4j.linalg.factory.Nd4j; - -import static org.junit.Assert.assertEquals; - -public class BatchCSVRecordTest { - - @Test - public void testBatchRecordCreationFromDataSet() { - DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{1, 1}, {1, 1}})); - - BatchCSVRecord batchCSVRecord = BatchCSVRecord.fromDataSet(dataSet); - assertEquals(2, batchCSVRecord.getRecords().size()); - } - -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/CSVSparkTransformTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/CSVSparkTransformTest.java deleted file mode 100644 index 7d1fe5f3b..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/CSVSparkTransformTest.java +++ /dev/null @@ -1,212 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
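BatchCSVRecord.fromDataSet above yields one record per example row; a sketch of roughly equivalent manual construction, assuming (as the tests here suggest) that each record flattens features followed by labels, with hypothetical values:

    import org.datavec.spark.inference.model.model.BatchCSVRecord;
    import org.datavec.spark.inference.model.model.SingleCSVRecord;

    public class BatchCSVRecordSketch {
        public static BatchCSVRecord twoRowBatch() {
            BatchCSVRecord batch = new BatchCSVRecord();
            // One record per example; features first, then label values.
            batch.add(new SingleCSVRecord("0.0", "0.0", "1.0", "1.0"));
            batch.add(new SingleCSVRecord("0.0", "0.0", "1.0", "1.0"));
            return batch; // getRecords().size() == 2, matching the assertion above
        }
    }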
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.transform; - -import org.datavec.api.transform.TransformProcess; -import org.datavec.api.transform.schema.Schema; -import org.datavec.api.transform.transform.integer.BaseIntegerTransform; -import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform; -import org.datavec.api.writable.DoubleWritable; -import org.datavec.api.writable.Text; -import org.datavec.api.writable.Writable; -import org.datavec.spark.inference.model.CSVSparkTransform; -import org.datavec.spark.inference.model.model.Base64NDArrayBody; -import org.datavec.spark.inference.model.model.BatchCSVRecord; -import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord; -import org.datavec.spark.inference.model.model.SingleCSVRecord; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.serde.base64.Nd4jBase64; - -import java.util.*; - -import static org.junit.Assert.*; - -public class CSVSparkTransformTest { - @Test - public void testTransformer() throws Exception { - List input = new ArrayList<>(); - input.add(new DoubleWritable(1.0)); - input.add(new DoubleWritable(2.0)); - - Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); - List output = new ArrayList<>(); - output.add(new Text("1.0")); - output.add(new Text("2.0")); - - TransformProcess transformProcess = - new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build(); - CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess); - String[] values = new String[] {"1.0", "2.0"}; - SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values)); - Base64NDArrayBody body = csvSparkTransform.toArray(new SingleCSVRecord(values)); - INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); - assertTrue(fromBase64.isVector()); -// System.out.println("Base 64ed array " + fromBase64); - } - - @Test - public void testTransformerBatch() throws Exception { - List input = new ArrayList<>(); - input.add(new DoubleWritable(1.0)); - input.add(new DoubleWritable(2.0)); - - Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); - List output = new ArrayList<>(); - output.add(new Text("1.0")); - output.add(new Text("2.0")); - - TransformProcess transformProcess = - new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build(); - CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess); - String[] values = new String[] {"1.0", "2.0"}; - SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values)); - BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); - for (int i = 0; i < 3; i++) - batchCSVRecord.add(record); - //data type is string, unable to convert - BatchCSVRecord batchCSVRecord1 = csvSparkTransform.transform(batchCSVRecord); - /* Base64NDArrayBody body = csvSparkTransform.toArray(batchCSVRecord1); - INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); - assertTrue(fromBase64.isMatrix()); - System.out.println("Base 64ed array " + fromBase64); */ - } - - - - @Test - public void testSingleBatchSequence() throws Exception { - List input = new ArrayList<>(); - input.add(new DoubleWritable(1.0)); - input.add(new DoubleWritable(2.0)); - - Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); - List output = 
new ArrayList<>(); - output.add(new Text("1.0")); - output.add(new Text("2.0")); - - TransformProcess transformProcess = - new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build(); - CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess); - String[] values = new String[] {"1.0", "2.0"}; - SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values)); - BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); - for (int i = 0; i < 3; i++) - batchCSVRecord.add(record); - BatchCSVRecord batchCSVRecord1 = csvSparkTransform.transform(batchCSVRecord); - SequenceBatchCSVRecord sequenceBatchCSVRecord = new SequenceBatchCSVRecord(); - sequenceBatchCSVRecord.add(Arrays.asList(batchCSVRecord)); - Base64NDArrayBody sequenceArray = csvSparkTransform.transformSequenceArray(sequenceBatchCSVRecord); - INDArray outputBody = Nd4jBase64.fromBase64(sequenceArray.getNdarray()); - - - //ensure accumulation - sequenceBatchCSVRecord.add(Arrays.asList(batchCSVRecord)); - sequenceArray = csvSparkTransform.transformSequenceArray(sequenceBatchCSVRecord); - assertArrayEquals(new long[]{2,2,3},Nd4jBase64.fromBase64(sequenceArray.getNdarray()).shape()); - - SequenceBatchCSVRecord transformed = csvSparkTransform.transformSequence(sequenceBatchCSVRecord); - assertNotNull(transformed.getRecords()); -// System.out.println(transformed); - - - } - - @Test - public void testSpecificSequence() throws Exception { - final Schema schema = new Schema.Builder() - .addColumnsString("action") - .build(); - - final TransformProcess transformProcess = new TransformProcess.Builder(schema) - .removeAllColumnsExceptFor("action") - .transform(new ConverToLowercase("action")) - .convertToSequence() - .transform(new TextToCharacterIndexTransform("action", "action_sequence", - defaultCharIndex(), false)) - .integerToOneHot("action_sequence",0,29) - .build(); - - final String[] data1 = new String[] { "test1" }; - final String[] data2 = new String[] { "test2" }; - final BatchCSVRecord batchCsvRecord = new BatchCSVRecord( - Arrays.asList( - new SingleCSVRecord(data1), - new SingleCSVRecord(data2))); - - final CSVSparkTransform transform = new CSVSparkTransform(transformProcess); -// System.out.println(transform.transformSequenceIncremental(batchCsvRecord)); - transform.transformSequenceIncremental(batchCsvRecord); - assertEquals(3,Nd4jBase64.fromBase64(transform.transformSequenceArrayIncremental(batchCsvRecord).getNdarray()).rank()); - - } - - private static Map defaultCharIndex() { - Map ret = new TreeMap<>(); - - ret.put('a',0); - ret.put('b',1); - ret.put('c',2); - ret.put('d',3); - ret.put('e',4); - ret.put('f',5); - ret.put('g',6); - ret.put('h',7); - ret.put('i',8); - ret.put('j',9); - ret.put('k',10); - ret.put('l',11); - ret.put('m',12); - ret.put('n',13); - ret.put('o',14); - ret.put('p',15); - ret.put('q',16); - ret.put('r',17); - ret.put('s',18); - ret.put('t',19); - ret.put('u',20); - ret.put('v',21); - ret.put('w',22); - ret.put('x',23); - ret.put('y',24); - ret.put('z',25); - ret.put('/',26); - ret.put(' ',27); - ret.put('(',28); - ret.put(')',29); - - return ret; - } - - public static class ConverToLowercase extends BaseIntegerTransform { - public ConverToLowercase(String column) { - super(column); - } - - public Text map(Writable writable) { - return new Text(writable.toString().toLowerCase()); - } - - public Object map(Object input) { - return new Text(input.toString().toLowerCase()); - } - } -} diff --git 
a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/ImageSparkTransformTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/ImageSparkTransformTest.java deleted file mode 100644 index 415730b18..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/ImageSparkTransformTest.java +++ /dev/null @@ -1,86 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.transform; - -import org.datavec.image.transform.ImageTransformProcess; -import org.datavec.spark.inference.model.ImageSparkTransform; -import org.datavec.spark.inference.model.model.Base64NDArrayBody; -import org.datavec.spark.inference.model.model.BatchImageRecord; -import org.datavec.spark.inference.model.model.SingleImageRecord; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.common.io.ClassPathResource; -import org.nd4j.serde.base64.Nd4jBase64; - -import java.io.File; - -import static org.junit.Assert.assertEquals; - -public class ImageSparkTransformTest { - - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - - @Test - public void testSingleImageSparkTransform() throws Exception { - int seed = 12345; - - File f1 = new ClassPathResource("datavec-spark-inference/testimages/class1/A.jpg").getFile(); - - SingleImageRecord imgRecord = new SingleImageRecord(f1.toURI()); - - ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(seed) - .scaleImageTransform(10).cropImageTransform(5).build(); - - ImageSparkTransform imgSparkTransform = new ImageSparkTransform(imgTransformProcess); - Base64NDArrayBody body = imgSparkTransform.toArray(imgRecord); - - INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); -// System.out.println("Base 64ed array " + fromBase64); - assertEquals(1, fromBase64.size(0)); - } - - @Test - public void testBatchImageSparkTransform() throws Exception { - int seed = 12345; - - File f0 = new ClassPathResource("datavec-spark-inference/testimages/class1/A.jpg").getFile(); - File f1 = new ClassPathResource("datavec-spark-inference/testimages/class1/B.png").getFile(); - File f2 = new ClassPathResource("datavec-spark-inference/testimages/class1/C.jpg").getFile(); - - BatchImageRecord batch = new BatchImageRecord(); - batch.add(f0.toURI()); - batch.add(f1.toURI()); - batch.add(f2.toURI()); - - ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(seed) - 
.scaleImageTransform(10).cropImageTransform(5).build(); - - ImageSparkTransform imgSparkTransform = new ImageSparkTransform(imgTransformProcess); - Base64NDArrayBody body = imgSparkTransform.toArray(batch); - - INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); -// System.out.println("Base 64ed array " + fromBase64); - assertEquals(3, fromBase64.size(0)); - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/SingleCSVRecordTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/SingleCSVRecordTest.java deleted file mode 100644 index 599f8eead..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/SingleCSVRecordTest.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.transform; - -import org.datavec.spark.inference.model.model.SingleCSVRecord; -import org.junit.Test; -import org.nd4j.linalg.dataset.DataSet; -import org.nd4j.linalg.factory.Nd4j; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; - -public class SingleCSVRecordTest { - - @Test(expected = IllegalArgumentException.class) - public void testVectorAssertion() { - DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(1, 1)); - SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet); - fail(singleCsvRecord.toString() + " should have thrown an exception"); - } - - @Test - public void testVectorOneHotLabel() { - DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{0, 1}, {1, 0}})); - - //assert - SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet.get(0)); - assertEquals(3, singleCsvRecord.getValues().size()); - - } - - @Test - public void testVectorRegression() { - DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{1, 1}, {1, 1}})); - - //assert - SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet.get(0)); - assertEquals(4, singleCsvRecord.getValues().size()); - - } - -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/SingleImageRecordTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/SingleImageRecordTest.java deleted file mode 100644 index 3c321e583..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/SingleImageRecordTest.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * 
****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.transform; - -import org.datavec.spark.inference.model.model.SingleImageRecord; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.nd4j.common.io.ClassPathResource; - -import java.io.File; - -public class SingleImageRecordTest { - - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - - @Test - public void testImageRecord() throws Exception { - File f = testDir.newFolder(); - new ClassPathResource("datavec-spark-inference/testimages/").copyDirectory(f); - File f0 = new File(f, "class0/0.jpg"); - File f1 = new File(f, "/class1/A.jpg"); - - SingleImageRecord imgRecord = new SingleImageRecord(f0.toURI()); - - // need jackson test? - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml deleted file mode 100644 index 8a65942db..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml +++ /dev/null @@ -1,154 +0,0 @@ - - - - - - 4.0.0 - - - org.datavec - datavec-spark-inference-parent - 1.0.0-SNAPSHOT - - - datavec-spark-inference-server_2.11 - - datavec-spark-inference-server - - - - 2.11.12 - 2.11 - 1.8 - 1.8 - - - - - org.datavec - datavec-spark-inference-model - ${datavec.version} - - - org.datavec - datavec-spark_2.11 - ${project.version} - - - org.datavec - datavec-data-image - - - joda-time - joda-time - - - org.apache.commons - commons-lang3 - - - org.hibernate - hibernate-validator - ${hibernate.version} - - - org.scala-lang - scala-library - ${scala.version} - - - org.scala-lang - scala-reflect - ${scala.version} - - - com.typesafe.play - play-java_2.11 - ${playframework.version} - - - com.google.code.findbugs - jsr305 - - - net.jodah - typetools - - - - - net.jodah - typetools - ${jodah.typetools.version} - - - com.typesafe.play - play-json_2.11 - ${playframework.version} - - - com.typesafe.play - play-server_2.11 - ${playframework.version} - - - com.typesafe.play - play_2.11 - ${playframework.version} - - - com.typesafe.play - play-netty-server_2.11 - ${playframework.version} - - - com.typesafe.akka - akka-cluster_2.11 - 2.5.23 - - - com.mashape.unirest - unirest-java - test - - - com.beust - jcommander - ${jcommander.version} - - - org.apache.spark - spark-core_2.11 - ${spark.version} - - - - - - test-nd4j-native - - - test-nd4j-cuda-11.0 - - - diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/CSVSparkTransformServer.java 
b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/CSVSparkTransformServer.java deleted file mode 100644 index 9ef085515..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/CSVSparkTransformServer.java +++ /dev/null @@ -1,352 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.server; - -import com.beust.jcommander.JCommander; -import com.beust.jcommander.ParameterException; -import lombok.Data; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.io.FileUtils; -import org.datavec.api.transform.TransformProcess; -import org.datavec.image.transform.ImageTransformProcess; -import org.datavec.spark.inference.model.CSVSparkTransform; -import org.datavec.spark.inference.model.model.*; -import play.BuiltInComponents; -import play.Mode; -import play.routing.Router; -import play.routing.RoutingDsl; -import play.server.Server; - -import java.io.File; -import java.io.IOException; -import java.util.Base64; -import java.util.Random; - -import static play.mvc.Results.*; - -@Slf4j -@Data -public class CSVSparkTransformServer extends SparkTransformServer { - private CSVSparkTransform transform; - - public void runMain(String[] args) throws Exception { - JCommander jcmdr = new JCommander(this); - - try { - jcmdr.parse(args); - } catch (ParameterException e) { - //User provides invalid input -> print the usage info - jcmdr.usage(); - if (jsonPath == null) - System.err.println("Json path parameter is missing."); - try { - Thread.sleep(500); - } catch (Exception e2) { - } - System.exit(1); - } - - if (jsonPath != null) { - String json = FileUtils.readFileToString(new File(jsonPath)); - TransformProcess transformProcess = TransformProcess.fromJson(json); - transform = new CSVSparkTransform(transformProcess); - } else { - log.warn("Server started with no json for transform process. 
Please ensure you specify a transform process via sending a post request with raw json" - + "to /transformprocess"); - } - - //Set play secret key, if required - //http://www.playframework.com/documentation/latest/ApplicationSecret - String crypto = System.getProperty("play.crypto.secret"); - if (crypto == null || "changeme".equals(crypto) || "".equals(crypto) ) { - byte[] newCrypto = new byte[1024]; - - new Random().nextBytes(newCrypto); - - String base64 = Base64.getEncoder().encodeToString(newCrypto); - System.setProperty("play.crypto.secret", base64); - } - - - server = Server.forRouter(Mode.PROD, port, this::createRouter); - } - - protected Router createRouter(BuiltInComponents b){ - RoutingDsl routingDsl = RoutingDsl.fromComponents(b); - - routingDsl.GET("/transformprocess").routingTo(req -> { - try { - if (transform == null) - return badRequest(); - return ok(transform.getTransformProcess().toJson()).as(contentType); - } catch (Exception e) { - log.error("Error in GET /transformprocess",e); - return internalServerError(e.getMessage()); - } - }); - - routingDsl.POST("/transformprocess").routingTo(req -> { - try { - TransformProcess transformProcess = TransformProcess.fromJson(getJsonText(req)); - setCSVTransformProcess(transformProcess); - log.info("Transform process initialized"); - return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType); - } catch (Exception e) { - log.error("Error in POST /transformprocess",e); - return internalServerError(e.getMessage()); - } - }); - - routingDsl.POST("/transformincremental").routingTo(req -> { - if (isSequence(req)) { - try { - BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); - if (record == null) - return badRequest(); - return ok(objectMapper.writeValueAsString(transformSequenceIncremental(record))).as(contentType); - } catch (Exception e) { - log.error("Error in /transformincremental", e); - return internalServerError(e.getMessage()); - } - } else { - try { - SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class); - if (record == null) - return badRequest(); - return ok(objectMapper.writeValueAsString(transformIncremental(record))).as(contentType); - } catch (Exception e) { - log.error("Error in /transformincremental", e); - return internalServerError(e.getMessage()); - } - } - }); - - routingDsl.POST("/transform").routingTo(req -> { - if (isSequence(req)) { - try { - SequenceBatchCSVRecord batch = transformSequence(objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class)); - if (batch == null) - return badRequest(); - return ok(objectMapper.writeValueAsString(batch)).as(contentType); - } catch (Exception e) { - log.error("Error in /transform", e); - return internalServerError(e.getMessage()); - } - } else { - try { - BatchCSVRecord input = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); - BatchCSVRecord batch = transform(input); - if (batch == null) - return badRequest(); - return ok(objectMapper.writeValueAsString(batch)).as(contentType); - } catch (Exception e) { - log.error("Error in /transform", e); - return internalServerError(e.getMessage()); - } - } - }); - - routingDsl.POST("/transformincrementalarray").routingTo(req -> { - if (isSequence(req)) { - try { - BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); - if (record == null) - return badRequest(); - return ok(objectMapper.writeValueAsString(transformSequenceArrayIncremental(record))).as(contentType); - } catch (Exception 
e) { - log.error("Error in /transformincrementalarray", e); - return internalServerError(e.getMessage()); - } - } else { - try { - SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class); - if (record == null) - return badRequest(); - return ok(objectMapper.writeValueAsString(transformArrayIncremental(record))).as(contentType); - } catch (Exception e) { - log.error("Error in /transformincrementalarray", e); - return internalServerError(e.getMessage()); - } - } - }); - - routingDsl.POST("/transformarray").routingTo(req -> { - if (isSequence(req)) { - try { - SequenceBatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class); - if (batchCSVRecord == null) - return badRequest(); - return ok(objectMapper.writeValueAsString(transformSequenceArray(batchCSVRecord))).as(contentType); - } catch (Exception e) { - log.error("Error in /transformarray", e); - return internalServerError(e.getMessage()); - } - } else { - try { - BatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); - if (batchCSVRecord == null) - return badRequest(); - return ok(objectMapper.writeValueAsString(transformArray(batchCSVRecord))).as(contentType); - } catch (Exception e) { - log.error("Error in /transformarray", e); - return internalServerError(e.getMessage()); - } - } - }); - - return routingDsl.build(); - } - - public static void main(String[] args) throws Exception { - new CSVSparkTransformServer().runMain(args); - } - - /** - * @param transformProcess - */ - @Override - public void setCSVTransformProcess(TransformProcess transformProcess) { - this.transform = new CSVSparkTransform(transformProcess); - } - - @Override - public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) { - log.error("Unsupported operation: setImageTransformProcess not supported for class", getClass()); - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - /** - * @return - */ - @Override - public TransformProcess getCSVTransformProcess() { - return transform.getTransformProcess(); - } - - @Override - public ImageTransformProcess getImageTransformProcess() { - log.error("Unsupported operation: getImageTransformProcess not supported for class", getClass()); - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - - /** - * - */ - /** - * @param transform - * @return - */ - @Override - public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) { - return this.transform.transformSequenceIncremental(transform); - } - - /** - * @param batchCSVRecord - * @return - */ - @Override - public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) { - return transform.transformSequence(batchCSVRecord); - } - - /** - * @param batchCSVRecord - * @return - */ - @Override - public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) { - return this.transform.transformSequenceArray(batchCSVRecord); - } - - /** - * @param singleCsvRecord - * @return - */ - @Override - public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) { - return this.transform.transformSequenceArrayIncremental(singleCsvRecord); - } - - /** - * @param transform - * @return - */ - @Override - public SingleCSVRecord transformIncremental(SingleCSVRecord transform) { - return this.transform.transform(transform); - } - - @Override - public SequenceBatchCSVRecord 
transform(SequenceBatchCSVRecord batchCSVRecord) { - return this.transform.transform(batchCSVRecord); - } - - /** - * @param batchCSVRecord - * @return - */ - @Override - public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) { - return transform.transform(batchCSVRecord); - } - - /** - * @param batchCSVRecord - * @return - */ - @Override - public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) { - try { - return this.transform.toArray(batchCSVRecord); - } catch (IOException e) { - log.error("Error in transformArray",e); - throw new IllegalStateException("Transform array shouldn't throw exception"); - } - } - - /** - * @param singleCsvRecord - * @return - */ - @Override - public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) { - try { - return this.transform.toArray(singleCsvRecord); - } catch (IOException e) { - log.error("Error in transformArrayIncremental",e); - throw new IllegalStateException("Transform array shouldn't throw exception"); - } - } - - @Override - public Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException { - log.error("Unsupported operation: transformIncrementalArray(SingleImageRecord) not supported for class", getClass()); - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - @Override - public Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException { - log.error("Unsupported operation: transformArray(BatchImageRecord) not supported for class", getClass()); - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/ImageSparkTransformServer.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/ImageSparkTransformServer.java deleted file mode 100644 index e7744ecaa..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/ImageSparkTransformServer.java +++ /dev/null @@ -1,261 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
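CSVSparkTransformServer above can also be started and initialized programmatically instead of via -j; a minimal sketch, assuming the server classes are on the classpath, with a hypothetical port and JSON path:

    import java.nio.charset.StandardCharsets;
    import java.nio.file.Files;
    import java.nio.file.Paths;

    import com.mashape.unirest.http.Unirest;
    import org.datavec.spark.inference.server.CSVSparkTransformServer;

    public class CsvServerLaunchSketch {
        public static void main(String[] args) throws Exception {
            CSVSparkTransformServer server = new CSVSparkTransformServer();
            server.runMain(new String[] {"-dp", "9050"});  // no -j: waits for a POSTed process

            String json = new String(
                    Files.readAllBytes(Paths.get("/tmp/transform-process.json")),
                    StandardCharsets.UTF_8);
            Unirest.post("http://localhost:9050/transformprocess")
                    .header("Content-Type", "application/json")
                    .body(json)
                    .asString();                           // installs the TransformProcess

            server.stop();
        }
    }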
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.server; - -import com.beust.jcommander.JCommander; -import com.beust.jcommander.ParameterException; -import lombok.Data; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.io.FileUtils; -import org.datavec.api.transform.TransformProcess; -import org.datavec.image.transform.ImageTransformProcess; -import org.datavec.spark.inference.model.ImageSparkTransform; -import org.datavec.spark.inference.model.model.*; -import play.BuiltInComponents; -import play.Mode; -import play.libs.Files; -import play.mvc.Http; -import play.routing.Router; -import play.routing.RoutingDsl; -import play.server.Server; - -import java.io.File; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -import static play.mvc.Results.*; - -@Slf4j -@Data -public class ImageSparkTransformServer extends SparkTransformServer { - private ImageSparkTransform transform; - - public void runMain(String[] args) throws Exception { - JCommander jcmdr = new JCommander(this); - - try { - jcmdr.parse(args); - } catch (ParameterException e) { - //User provides invalid input -> print the usage info - jcmdr.usage(); - if (jsonPath == null) - System.err.println("Json path parameter is missing."); - try { - Thread.sleep(500); - } catch (Exception e2) { - } - System.exit(1); - } - - if (jsonPath != null) { - String json = FileUtils.readFileToString(new File(jsonPath)); - ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(json); - transform = new ImageSparkTransform(transformProcess); - } else { - log.warn("Server started with no json for transform process. Please ensure you specify a transform process via sending a post request with raw json" - + "to /transformprocess"); - } - - server = Server.forRouter(Mode.PROD, port, this::createRouter); - } - - protected Router createRouter(BuiltInComponents builtInComponents){ - RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents); - - routingDsl.GET("/transformprocess").routingTo(req -> { - try { - if (transform == null) - return badRequest(); - log.info("Transform process initialized"); - return ok(objectMapper.writeValueAsString(transform.getImageTransformProcess())).as(contentType); - } catch (Exception e) { - log.error("",e); - return internalServerError(); - } - }); - - routingDsl.POST("/transformprocess").routingTo(req -> { - try { - ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(getJsonText(req)); - setImageTransformProcess(transformProcess); - log.info("Transform process initialized"); - return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType); - } catch (Exception e) { - log.error("",e); - return internalServerError(); - } - }); - - routingDsl.POST("/transformincrementalarray").routingTo(req -> { - try { - SingleImageRecord record = objectMapper.readValue(getJsonText(req), SingleImageRecord.class); - if (record == null) - return badRequest(); - return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType); - } catch (Exception e) { - log.error("",e); - return internalServerError(); - } - }); - - routingDsl.POST("/transformincrementalimage").routingTo(req -> { - try { - Http.MultipartFormData body = req.body().asMultipartFormData(); - List> files = body.getFiles(); - if (files.isEmpty() || files.get(0).getRef() == null ) { - return badRequest(); - } - - File file = 
files.get(0).getRef().path().toFile(); - SingleImageRecord record = new SingleImageRecord(file.toURI()); - - return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType); - } catch (Exception e) { - log.error("",e); - return internalServerError(); - } - }); - - routingDsl.POST("/transformarray").routingTo(req -> { - try { - BatchImageRecord batch = objectMapper.readValue(getJsonText(req), BatchImageRecord.class); - if (batch == null) - return badRequest(); - return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType); - } catch (Exception e) { - log.error("",e); - return internalServerError(); - } - }); - - routingDsl.POST("/transformimage").routingTo(req -> { - try { - Http.MultipartFormData body = req.body().asMultipartFormData(); - List> files = body.getFiles(); - if (files.size() == 0) { - return badRequest(); - } - - List records = new ArrayList<>(); - - for (Http.MultipartFormData.FilePart filePart : files) { - Files.TemporaryFile file = filePart.getRef(); - if (file != null) { - SingleImageRecord record = new SingleImageRecord(file.path().toUri()); - records.add(record); - } - } - - BatchImageRecord batch = new BatchImageRecord(records); - - return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType); - } catch (Exception e) { - log.error("",e); - return internalServerError(); - } - }); - - return routingDsl.build(); - } - - @Override - public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) { - throw new UnsupportedOperationException(); - } - - @Override - public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) { - throw new UnsupportedOperationException(); - - } - - @Override - public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) { - throw new UnsupportedOperationException(); - - } - - @Override - public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) { - throw new UnsupportedOperationException(); - - } - - @Override - public void setCSVTransformProcess(TransformProcess transformProcess) { - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - @Override - public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) { - this.transform = new ImageSparkTransform(imageTransformProcess); - } - - @Override - public TransformProcess getCSVTransformProcess() { - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - @Override - public ImageTransformProcess getImageTransformProcess() { - return transform.getImageTransformProcess(); - } - - @Override - public SingleCSVRecord transformIncremental(SingleCSVRecord singleCsvRecord) { - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - @Override - public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) { - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - @Override - public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) { - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - @Override - public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) { - throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - @Override - public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) { - throw new 
UnsupportedOperationException("Invalid operation for " + this.getClass()); - } - - @Override - public Base64NDArrayBody transformIncrementalArray(SingleImageRecord record) throws IOException { - return transform.toArray(record); - } - - @Override - public Base64NDArrayBody transformArray(BatchImageRecord batch) throws IOException { - return transform.toArray(batch); - } - - public static void main(String[] args) throws Exception { - new ImageSparkTransformServer().runMain(args); - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/SparkTransformServer.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/SparkTransformServer.java deleted file mode 100644 index c89ef90cc..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/SparkTransformServer.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
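The /transformincrementalimage and /transformimage routes above read multipart file parts rather than JSON bodies; a client-side sketch with unirest-java, where the host, port, and image path are hypothetical:

    import java.io.File;

    import com.mashape.unirest.http.HttpResponse;
    import com.mashape.unirest.http.Unirest;

    public class ImageUploadSketch {
        public static void main(String[] args) throws Exception {
            HttpResponse<String> resp = Unirest.post("http://localhost:9060/transformincrementalimage")
                    .field("file", new File("/tmp/A.jpg"))  // the first multipart part is used
                    .asString();
            System.out.println(resp.getBody());             // JSON-serialized Base64NDArrayBody
        }
    }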
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.server; - -import com.beust.jcommander.Parameter; -import com.fasterxml.jackson.databind.JsonNode; -import org.datavec.spark.inference.model.model.Base64NDArrayBody; -import org.datavec.spark.inference.model.model.BatchCSVRecord; -import org.datavec.spark.inference.model.service.DataVecTransformService; -import org.nd4j.shade.jackson.databind.ObjectMapper; -import play.mvc.Http; -import play.server.Server; - -public abstract class SparkTransformServer implements DataVecTransformService { - @Parameter(names = {"-j", "--jsonPath"}, arity = 1) - protected String jsonPath = null; - @Parameter(names = {"-dp", "--dataVecPort"}, arity = 1) - protected int port = 9000; - @Parameter(names = {"-dt", "--dataType"}, arity = 1) - private TransformDataType transformDataType = null; - protected Server server; - protected static ObjectMapper objectMapper = new ObjectMapper(); - protected static String contentType = "application/json"; - - public abstract void runMain(String[] args) throws Exception; - - /** - * Stop the server - */ - public void stop() { - if (server != null) - server.stop(); - } - - protected boolean isSequence(Http.Request request) { - return request.hasHeader(SEQUENCE_OR_NOT_HEADER) - && request.header(SEQUENCE_OR_NOT_HEADER).get().equalsIgnoreCase("true"); - } - - protected String getJsonText(Http.Request request) { - JsonNode tryJson = request.body().asJson(); - if (tryJson != null) - return tryJson.toString(); - else - return request.body().asText(); - } - - public abstract Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord); -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/SparkTransformServerChooser.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/SparkTransformServerChooser.java deleted file mode 100644 index aa4945ddb..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/SparkTransformServerChooser.java +++ /dev/null @@ -1,76 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
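isSequence(...) above switches each POST route between batch and sequence handling based on the "Sequence" header; a sketch of a client opting into the sequence path, with hypothetical endpoint values and a placeholder payload:

    import com.mashape.unirest.http.Unirest;

    public class SequenceHeaderSketch {
        public static void main(String[] args) throws Exception {
            // Normally a serialized SequenceBatchCSVRecord; an empty object is a placeholder here.
            String payload = "{}";
            Unirest.post("http://localhost:9050/transform")
                    .header("Content-Type", "application/json")
                    .header("Sequence", "true")  // matched case-insensitively by isSequence(...)
                    .body(payload)
                    .asString();
        }
    }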
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.server; - -import lombok.Data; -import lombok.extern.slf4j.Slf4j; - -import java.io.InvalidClassException; -import java.util.Arrays; -import java.util.List; - -@Data -@Slf4j -public class SparkTransformServerChooser { - private SparkTransformServer sparkTransformServer = null; - private TransformDataType transformDataType = null; - - public void runMain(String[] args) throws Exception { - - int pos = getMatchingPosition(args, "-dt", "--dataType"); - if (pos == -1) { - log.error("no valid options"); - log.error("-dt, --dataType Options: [CSV, IMAGE]"); - throw new Exception("no valid options"); - } else { - transformDataType = TransformDataType.valueOf(args[pos + 1]); - } - - switch (transformDataType) { - case CSV: - sparkTransformServer = new CSVSparkTransformServer(); - break; - case IMAGE: - sparkTransformServer = new ImageSparkTransformServer(); - break; - default: - throw new InvalidClassException("no matching SparkTransform class"); - } - - sparkTransformServer.runMain(args); - } - - private int getMatchingPosition(String[] args, String... options) { - List optionList = Arrays.asList(options); - - for (int i = 0; i < args.length; i++) { - if (optionList.contains(args[i])) { - return i; - } - } - return -1; - } - - - public static void main(String[] args) throws Exception { - new SparkTransformServerChooser().runMain(args); - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/TransformDataType.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/TransformDataType.java deleted file mode 100644 index 643cd5652..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/TransformDataType.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.inference.server; - -public enum TransformDataType { - CSV, IMAGE, -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/resources/application.conf b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/resources/application.conf deleted file mode 100644 index 28a4aa208..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/resources/application.conf +++ /dev/null @@ -1,350 +0,0 @@ -# This is the main configuration file for the application. 
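SparkTransformServerChooser above reads -dt itself and then forwards the full argument array to the selected server, whose JCommander setup parses the remaining flags; a launch sketch with a hypothetical port and JSON path:

    import org.datavec.spark.inference.server.SparkTransformServerChooser;

    public class ChooserLaunchSketch {
        public static void main(String[] args) throws Exception {
            // -dt picks CSV vs IMAGE; -dp and -j are handled by the chosen server.
            SparkTransformServerChooser.main(new String[] {
                    "-dt", "CSV", "-dp", "9050", "-j", "/tmp/transform-process.json"});
        }
    }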
-# https://www.playframework.com/documentation/latest/ConfigFile -# ~~~~~ -# Play uses HOCON as its configuration file format. HOCON has a number -# of advantages over other config formats, but there are two things that -# can be used when modifying settings. -# -# You can include other configuration files in this main application.conf file: -#include "extra-config.conf" -# -# You can declare variables and substitute for them: -#mykey = ${some.value} -# -# And if an environment variable exists when there is no other subsitution, then -# HOCON will fall back to substituting environment variable: -#mykey = ${JAVA_HOME} - -## Akka -# https://www.playframework.com/documentation/latest/ScalaAkka#Configuration -# https://www.playframework.com/documentation/latest/JavaAkka#Configuration -# ~~~~~ -# Play uses Akka internally and exposes Akka Streams and actors in Websockets and -# other streaming HTTP responses. -akka { - # "akka.log-config-on-start" is extraordinarly useful because it log the complete - # configuration at INFO level, including defaults and overrides, so it s worth - # putting at the very top. - # - # Put the following in your conf/logback.xml file: - # - # - # - # And then uncomment this line to debug the configuration. - # - #log-config-on-start = true -} - -## Modules -# https://www.playframework.com/documentation/latest/Modules -# ~~~~~ -# Control which modules are loaded when Play starts. Note that modules are -# the replacement for "GlobalSettings", which are deprecated in 2.5.x. -# Please see https://www.playframework.com/documentation/latest/GlobalSettings -# for more information. -# -# You can also extend Play functionality by using one of the publically available -# Play modules: https://playframework.com/documentation/latest/ModuleDirectory -play.modules { - # By default, Play will load any class called Module that is defined - # in the root package (the "app" directory), or you can define them - # explicitly below. - # If there are any built-in modules that you want to disable, you can list them here. - #enabled += my.application.Module - - # If there are any built-in modules that you want to disable, you can list them here. - #disabled += "" -} - -## Internationalisation -# https://www.playframework.com/documentation/latest/JavaI18N -# https://www.playframework.com/documentation/latest/ScalaI18N -# ~~~~~ -# Play comes with its own i18n settings, which allow the user's preferred language -# to map through to internal messages, or allow the language to be stored in a cookie. -play.i18n { - # The application languages - langs = [ "en" ] - - # Whether the language cookie should be secure or not - #langCookieSecure = true - - # Whether the HTTP only attribute of the cookie should be set to true - #langCookieHttpOnly = true -} - -## Play HTTP settings -# ~~~~~ -play.http { - ## Router - # https://www.playframework.com/documentation/latest/JavaRouting - # https://www.playframework.com/documentation/latest/ScalaRouting - # ~~~~~ - # Define the Router object to use for this application. - # This router will be looked up first when the application is starting up, - # so make sure this is the entry point. - # Furthermore, it's assumed your route file is named properly. - # So for an application router like `my.application.Router`, - # you may need to define a router file `conf/my.application.routes`. 
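Because HOCON gives system properties the highest precedence, most keys in this file can usually be overridden before the server boots; a sketch under that assumption, with a hypothetical port:

    import org.datavec.spark.inference.server.CSVSparkTransformServer;

    public class ConfOverrideSketch {
        public static void main(String[] args) throws Exception {
            // Raise the JSON body limit set by play.http.parser.maxMemoryBuffer at the end of this file.
            System.setProperty("play.http.parser.maxMemoryBuffer", "10M");
            new CSVSparkTransformServer().runMain(new String[] {"-dp", "9050"});
        }
    }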
- # Default to Routes in the root package (aka "apps" folder) (and conf/routes) - #router = my.application.Router - - ## Action Creator - # https://www.playframework.com/documentation/latest/JavaActionCreator - # ~~~~~ - #actionCreator = null - - ## ErrorHandler - # https://www.playframework.com/documentation/latest/JavaRouting - # https://www.playframework.com/documentation/latest/ScalaRouting - # ~~~~~ - # If null, will attempt to load a class called ErrorHandler in the root package, - #errorHandler = null - - ## Filters - # https://www.playframework.com/documentation/latest/ScalaHttpFilters - # https://www.playframework.com/documentation/latest/JavaHttpFilters - # ~~~~~ - # Filters run code on every request. They can be used to perform - # common logic for all your actions, e.g. adding common headers. - # Defaults to "Filters" in the root package (aka "apps" folder) - # Alternatively you can explicitly register a class here. - #filters += my.application.Filters - - ## Session & Flash - # https://www.playframework.com/documentation/latest/JavaSessionFlash - # https://www.playframework.com/documentation/latest/ScalaSessionFlash - # ~~~~~ - session { - # Sets the cookie to be sent only over HTTPS. - #secure = true - - # Sets the cookie to be accessed only by the server. - #httpOnly = true - - # Sets the max-age field of the cookie to 5 minutes. - # NOTE: this only sets when the browser will discard the cookie. Play will consider any - # cookie value with a valid signature to be a valid session forever. To implement a server side session timeout, - # you need to put a timestamp in the session and check it at regular intervals to possibly expire it. - #maxAge = 300 - - # Sets the domain on the session cookie. - #domain = "example.com" - } - - flash { - # Sets the cookie to be sent only over HTTPS. - #secure = true - - # Sets the cookie to be accessed only by the server. - #httpOnly = true - } -} - -## Netty Provider -# https://www.playframework.com/documentation/latest/SettingsNetty -# ~~~~~ -play.server.netty { - # Whether the Netty wire should be logged - #log.wire = true - - # If you run Play on Linux, you can use Netty's native socket transport - # for higher performance with less garbage. - #transport = "native" -} - -## WS (HTTP Client) -# https://www.playframework.com/documentation/latest/ScalaWS#Configuring-WS -# ~~~~~ -# The HTTP client primarily used for REST APIs. The default client can be -# configured directly, but you can also create different client instances -# with customized settings. You must enable this by adding to build.sbt: -# -# libraryDependencies += ws // or javaWs if using java -# -play.ws { - # Sets HTTP requests not to follow 302 requests - #followRedirects = false - - # Sets the maximum number of open HTTP connections for the client. - #ahc.maxConnectionsTotal = 50 - - ## WS SSL - # https://www.playframework.com/documentation/latest/WsSSL - # ~~~~~ - ssl { - # Configuring HTTPS with Play WS does not require programming. You can - # set up both trustManager and keyManager for mutual authentication, and - # turn on JSSE debugging in development with a reload. - #debug.handshake = true - #trustManager = { - # stores = [ - # { type = "JKS", path = "exampletrust.jks" } - # ] - #} - } -} - -## Cache -# https://www.playframework.com/documentation/latest/JavaCache -# https://www.playframework.com/documentation/latest/ScalaCache -# ~~~~~ -# Play comes with an integrated cache API that can reduce the operational -# overhead of repeated requests. 
You must enable this by adding to build.sbt: -# -# libraryDependencies += cache -# -play.cache { - # If you want to bind several caches, you can bind the individually - #bindCaches = ["db-cache", "user-cache", "session-cache"] -} - -## Filters -# https://www.playframework.com/documentation/latest/Filters -# ~~~~~ -# There are a number of built-in filters that can be enabled and configured -# to give Play greater security. You must enable this by adding to build.sbt: -# -# libraryDependencies += filters -# -play.filters { - ## CORS filter configuration - # https://www.playframework.com/documentation/latest/CorsFilter - # ~~~~~ - # CORS is a protocol that allows web applications to make requests from the browser - # across different domains. - # NOTE: You MUST apply the CORS configuration before the CSRF filter, as CSRF has - # dependencies on CORS settings. - cors { - # Filter paths by a whitelist of path prefixes - #pathPrefixes = ["/some/path", ...] - - # The allowed origins. If null, all origins are allowed. - #allowedOrigins = ["http://www.example.com"] - - # The allowed HTTP methods. If null, all methods are allowed - #allowedHttpMethods = ["GET", "POST"] - } - - ## CSRF Filter - # https://www.playframework.com/documentation/latest/ScalaCsrf#Applying-a-global-CSRF-filter - # https://www.playframework.com/documentation/latest/JavaCsrf#Applying-a-global-CSRF-filter - # ~~~~~ - # Play supports multiple methods for verifying that a request is not a CSRF request. - # The primary mechanism is a CSRF token. This token gets placed either in the query string - # or body of every form submitted, and also gets placed in the users session. - # Play then verifies that both tokens are present and match. - csrf { - # Sets the cookie to be sent only over HTTPS - #cookie.secure = true - - # Defaults to CSRFErrorHandler in the root package. - #errorHandler = MyCSRFErrorHandler - } - - ## Security headers filter configuration - # https://www.playframework.com/documentation/latest/SecurityHeaders - # ~~~~~ - # Defines security headers that prevent XSS attacks. - # If enabled, then all options are set to the below configuration by default: - headers { - # The X-Frame-Options header. If null, the header is not set. - #frameOptions = "DENY" - - # The X-XSS-Protection header. If null, the header is not set. - #xssProtection = "1; mode=block" - - # The X-Content-Type-Options header. If null, the header is not set. - #contentTypeOptions = "nosniff" - - # The X-Permitted-Cross-Domain-Policies header. If null, the header is not set. - #permittedCrossDomainPolicies = "master-only" - - # The Content-Security-Policy header. If null, the header is not set. - #contentSecurityPolicy = "default-src 'self'" - } - - ## Allowed hosts filter configuration - # https://www.playframework.com/documentation/latest/AllowedHostsFilter - # ~~~~~ - # Play provides a filter that lets you configure which hosts can access your application. - # This is useful to prevent cache poisoning attacks. - hosts { - # Allow requests to example.com, its subdomains, and localhost:9000. - #allowed = [".example.com", "localhost:9000"] - } -} - -## Evolutions -# https://www.playframework.com/documentation/latest/Evolutions -# ~~~~~ -# Evolutions allows database scripts to be automatically run on startup in dev mode -# for database migrations. 
You must enable this by adding to build.sbt: -# -# libraryDependencies += evolutions -# -play.evolutions { - # You can disable evolutions for a specific datasource if necessary - #db.default.enabled = false -} - -## Database Connection Pool -# https://www.playframework.com/documentation/latest/SettingsJDBC -# ~~~~~ -# Play doesn't require a JDBC database to run, but you can easily enable one. -# -# libraryDependencies += jdbc -# -play.db { - # The combination of these two settings results in "db.default" as the - # default JDBC pool: - #config = "db" - #default = "default" - - # Play uses HikariCP as the default connection pool. You can override - # settings by changing the prototype: - prototype { - # Sets a fixed JDBC connection pool size of 50 - #hikaricp.minimumIdle = 50 - #hikaricp.maximumPoolSize = 50 - } -} - -## JDBC Datasource -# https://www.playframework.com/documentation/latest/JavaDatabase -# https://www.playframework.com/documentation/latest/ScalaDatabase -# ~~~~~ -# Once the JDBC datasource is set up, you can work with several different -# database options: -# -# Slick (Scala preferred option): https://www.playframework.com/documentation/latest/PlaySlick -# JPA (Java preferred option): https://playframework.com/documentation/latest/JavaJPA -# EBean: https://playframework.com/documentation/latest/JavaEbean -# Anorm: https://www.playframework.com/documentation/latest/ScalaAnorm -# -db { - # You can declare as many datasources as you want. - # By convention, the default datasource is named `default` - - # https://www.playframework.com/documentation/latest/Developing-with-the-H2-Database - default.driver = org.h2.Driver - default.url = "jdbc:h2:mem:play" - #default.username = sa - #default.password = "" - - # You can expose this datasource via JNDI if needed (useful for JPA) - default.jndiName=DefaultDS - - # You can turn on SQL logging for any datasource - # https://www.playframework.com/documentation/latest/Highlights25#Logging-SQL-statements - #default.logSql=true -} - -jpa.default=defaultPersistenceUnit - - -# Increase the default maximum POST length - used for the remote listener functionality. -# Without this, larger networks can trigger HTTP 413 (Request Entity Too Large) responses. -# parsers.text.maxLength is deprecated; use play.http.parser.maxMemoryBuffer instead. -#parsers.text.maxLength=10M -play.http.parser.maxMemoryBuffer=10M diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java deleted file mode 100644 index ab76b206e..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ -package org.datavec.spark.transform; - -import lombok.extern.slf4j.Slf4j; -import org.nd4j.common.tests.AbstractAssertTestsClass; -import org.nd4j.common.tests.BaseND4JTest; - -import java.util.*; - -@Slf4j -public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { - - @Override - protected Set<Class<?>> getExclusions() { - // Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) - return new HashSet<>(); - } - - @Override - protected String getPackageName() { - return "org.datavec.spark.transform"; - } - - @Override - protected Class<?> getBaseClass() { - return BaseND4JTest.class; - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java deleted file mode 100644 index 8f309caff..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java +++ /dev/null @@ -1,127 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
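The CSV server tests that follow each register the same Jackson-backed Unirest ObjectMapper inline. A hedged sketch of that setup as a reusable helper (the UnirestJson class name is illustrative), with explicit generic signatures:

    import com.mashape.unirest.http.ObjectMapper;
    import com.mashape.unirest.http.Unirest;
    import java.io.IOException;

    public final class UnirestJson {
        private UnirestJson() {}

        // Call once before issuing requests, exactly as the tests below do inline
        public static void register() {
            Unirest.setObjectMapper(new ObjectMapper() {
                private final org.nd4j.shade.jackson.databind.ObjectMapper jackson =
                        new org.nd4j.shade.jackson.databind.ObjectMapper();

                public <T> T readValue(String value, Class<T> valueType) {
                    try {
                        return jackson.readValue(value, valueType);
                    } catch (IOException e) {
                        throw new RuntimeException(e);
                    }
                }

                public String writeValue(Object value) {
                    try {
                        return jackson.writeValueAsString(value);
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                }
            });
        }
    }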
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.transform; - -import com.mashape.unirest.http.JsonNode; -import com.mashape.unirest.http.ObjectMapper; -import com.mashape.unirest.http.Unirest; -import org.apache.commons.io.FileUtils; -import org.datavec.api.transform.TransformProcess; -import org.datavec.api.transform.schema.Schema; -import org.datavec.spark.inference.server.CSVSparkTransformServer; -import org.datavec.spark.inference.model.model.Base64NDArrayBody; -import org.datavec.spark.inference.model.model.BatchCSVRecord; -import org.datavec.spark.inference.model.model.SingleCSVRecord; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; - -import java.io.File; -import java.io.IOException; -import java.util.UUID; - -import static org.junit.Assert.assertTrue; -import static org.junit.Assume.assumeNotNull; - -public class CSVSparkTransformServerNoJsonTest { - - private static CSVSparkTransformServer server; - private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); - private static TransformProcess transformProcess = - new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build(); - private static File fileSave = new File(UUID.randomUUID().toString() + ".json"); - - @BeforeClass - public static void before() throws Exception { - server = new CSVSparkTransformServer(); - FileUtils.write(fileSave, transformProcess.toJson()); - - // Only one time - Unirest.setObjectMapper(new ObjectMapper() { - private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = - new org.nd4j.shade.jackson.databind.ObjectMapper(); - - public T readValue(String value, Class valueType) { - try { - return jacksonObjectMapper.readValue(value, valueType); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - public String writeValue(Object value) { - try { - return jacksonObjectMapper.writeValueAsString(value); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - }); - - server.runMain(new String[] {"-dp", "9050"}); - } - - @AfterClass - public static void after() throws Exception { - fileSave.delete(); - server.stop(); - - } - - - - @Test - public void testServer() throws Exception { - assertTrue(server.getTransform() == null); - JsonNode jsonStatus = Unirest.post("http://localhost:9050/transformprocess") - .header("accept", "application/json").header("Content-Type", "application/json") - .body(transformProcess.toJson()).asJson().getBody(); - assumeNotNull(server.getTransform()); - - String[] values = new String[] {"1.0", "2.0"}; - SingleCSVRecord record = new SingleCSVRecord(values); - JsonNode jsonNode = - Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json") - .header("Content-Type", "application/json").body(record).asJson().getBody(); - SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental") - .header("accept", "application/json").header("Content-Type", "application/json").body(record) - .asObject(SingleCSVRecord.class).getBody(); - - BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); - for (int i = 0; i < 3; i++) - batchCSVRecord.add(singleCsvRecord); - /* BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform") - .header("accept", "application/json").header("Content-Type", "application/json") - 
.body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody(); - - Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray") - .header("accept", "application/json").header("Content-Type", "application/json").body(record) - .asObject(Base64NDArrayBody.class).getBody(); -*/ - Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray") - .header("accept", "application/json").header("Content-Type", "application/json") - .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody(); - - - - } - -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java deleted file mode 100644 index a3af5f2c6..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java +++ /dev/null @@ -1,121 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
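The batchArray1 fetched at the end of the test above is never inspected. Decoding it is a one-liner, sketched here under the assumptions that org.nd4j.serde.base64.Nd4jBase64 and org.junit.Assert.assertEquals are imported and that the DataVec Base64NDArrayBody exposes getNdarray() like the deeplearning4j model class shown later in this diff:

    INDArray decoded = Nd4jBase64.fromBase64(batchArray1.getNdarray());
    assertEquals(3, decoded.size(0)); // one row per SingleCSVRecord added to the batch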
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.transform; - - -import com.mashape.unirest.http.JsonNode; -import com.mashape.unirest.http.ObjectMapper; -import com.mashape.unirest.http.Unirest; -import org.apache.commons.io.FileUtils; -import org.datavec.api.transform.TransformProcess; -import org.datavec.api.transform.schema.Schema; -import org.datavec.spark.inference.server.CSVSparkTransformServer; -import org.datavec.spark.inference.model.model.Base64NDArrayBody; -import org.datavec.spark.inference.model.model.BatchCSVRecord; -import org.datavec.spark.inference.model.model.SingleCSVRecord; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; - -import java.io.File; -import java.io.IOException; -import java.util.UUID; - -public class CSVSparkTransformServerTest { - - private static CSVSparkTransformServer server; - private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); - private static TransformProcess transformProcess = - new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build(); - private static File fileSave = new File(UUID.randomUUID().toString() + ".json"); - - @BeforeClass - public static void before() throws Exception { - server = new CSVSparkTransformServer(); - FileUtils.write(fileSave, transformProcess.toJson()); - // Only one time - - Unirest.setObjectMapper(new ObjectMapper() { - private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = - new org.nd4j.shade.jackson.databind.ObjectMapper(); - - public T readValue(String value, Class valueType) { - try { - return jacksonObjectMapper.readValue(value, valueType); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - public String writeValue(Object value) { - try { - return jacksonObjectMapper.writeValueAsString(value); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - }); - - server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9050"}); - } - - @AfterClass - public static void after() throws Exception { - fileSave.deleteOnExit(); - server.stop(); - - } - - - - @Test - public void testServer() throws Exception { - String[] values = new String[] {"1.0", "2.0"}; - SingleCSVRecord record = new SingleCSVRecord(values); - JsonNode jsonNode = - Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json") - .header("Content-Type", "application/json").body(record).asJson().getBody(); - SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental") - .header("accept", "application/json").header("Content-Type", "application/json").body(record) - .asObject(SingleCSVRecord.class).getBody(); - - BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); - for (int i = 0; i < 3; i++) - batchCSVRecord.add(singleCsvRecord); - BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform") - .header("accept", "application/json").header("Content-Type", "application/json") - .body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody(); - - Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray") - .header("accept", "application/json").header("Content-Type", "application/json").body(record) - .asObject(Base64NDArrayBody.class).getBody(); - - Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray") - 
.header("accept", "application/json").header("Content-Type", "application/json") - .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody(); - - - - - - } - -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java deleted file mode 100644 index 12f754acd..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java +++ /dev/null @@ -1,164 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.transform; - - -import com.mashape.unirest.http.JsonNode; -import com.mashape.unirest.http.ObjectMapper; -import com.mashape.unirest.http.Unirest; -import org.apache.commons.io.FileUtils; -import org.datavec.image.transform.ImageTransformProcess; -import org.datavec.spark.inference.server.ImageSparkTransformServer; -import org.datavec.spark.inference.model.model.Base64NDArrayBody; -import org.datavec.spark.inference.model.model.BatchImageRecord; -import org.datavec.spark.inference.model.model.SingleImageRecord; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.common.io.ClassPathResource; -import org.nd4j.serde.base64.Nd4jBase64; - -import java.io.File; -import java.io.IOException; -import java.util.UUID; - -import static org.junit.Assert.assertEquals; - -public class ImageSparkTransformServerTest { - - @Rule - public TemporaryFolder testDir = new TemporaryFolder(); - - private static ImageSparkTransformServer server; - private static File fileSave = new File(UUID.randomUUID().toString() + ".json"); - - @BeforeClass - public static void before() throws Exception { - server = new ImageSparkTransformServer(); - - ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(12345) - .scaleImageTransform(10).cropImageTransform(5).build(); - - FileUtils.write(fileSave, imgTransformProcess.toJson()); - - Unirest.setObjectMapper(new ObjectMapper() { - private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = - new org.nd4j.shade.jackson.databind.ObjectMapper(); - - public T readValue(String value, Class valueType) { - try { - return jacksonObjectMapper.readValue(value, valueType); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - public String writeValue(Object value) { 
- try { - return jacksonObjectMapper.writeValueAsString(value); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - }); - - server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9060"}); - } - - @AfterClass - public static void after() throws Exception { - fileSave.deleteOnExit(); - server.stop(); - - } - - @Test - public void testImageServer() throws Exception { - SingleImageRecord record = - new SingleImageRecord(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI()); - JsonNode jsonNode = Unirest.post("http://localhost:9060/transformincrementalarray") - .header("accept", "application/json").header("Content-Type", "application/json").body(record) - .asJson().getBody(); - Base64NDArrayBody array = Unirest.post("http://localhost:9060/transformincrementalarray") - .header("accept", "application/json").header("Content-Type", "application/json").body(record) - .asObject(Base64NDArrayBody.class).getBody(); - - BatchImageRecord batch = new BatchImageRecord(); - batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI()); - batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile().toURI()); - batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile().toURI()); - - JsonNode jsonNodeBatch = - Unirest.post("http://localhost:9060/transformarray").header("accept", "application/json") - .header("Content-Type", "application/json").body(batch).asJson().getBody(); - Base64NDArrayBody batchArray = Unirest.post("http://localhost:9060/transformarray") - .header("accept", "application/json").header("Content-Type", "application/json").body(batch) - .asObject(Base64NDArrayBody.class).getBody(); - - INDArray result = getNDArray(jsonNode); - assertEquals(1, result.size(0)); - - INDArray batchResult = getNDArray(jsonNodeBatch); - assertEquals(3, batchResult.size(0)); - -// System.out.println(array); - } - - @Test - public void testImageServerMultipart() throws Exception { - JsonNode jsonNode = Unirest.post("http://localhost:9060/transformimage") - .header("accept", "application/json") - .field("file1", new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile()) - .field("file2", new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile()) - .field("file3", new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile()) - .asJson().getBody(); - - - INDArray batchResult = getNDArray(jsonNode); - assertEquals(3, batchResult.size(0)); - -// System.out.println(batchResult); - } - - @Test - public void testImageServerSingleMultipart() throws Exception { - File f = testDir.newFolder(); - File imgFile = new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getTempFileFromArchive(f); - - JsonNode jsonNode = Unirest.post("http://localhost:9060/transformimage") - .header("accept", "application/json") - .field("file1", imgFile) - .asJson().getBody(); - - - INDArray result = getNDArray(jsonNode); - assertEquals(1, result.size(0)); - -// System.out.println(result); - } - - public INDArray getNDArray(JsonNode node) throws IOException { - return Nd4jBase64.fromBase64(node.getObject().getString("ndarray")); - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java 
b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java deleted file mode 100644 index 831dd24f4..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java +++ /dev/null @@ -1,168 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.datavec.spark.transform; - - -import com.mashape.unirest.http.JsonNode; -import com.mashape.unirest.http.ObjectMapper; -import com.mashape.unirest.http.Unirest; -import org.apache.commons.io.FileUtils; -import org.datavec.api.transform.TransformProcess; -import org.datavec.api.transform.schema.Schema; -import org.datavec.image.transform.ImageTransformProcess; -import org.datavec.spark.inference.server.SparkTransformServerChooser; -import org.datavec.spark.inference.server.TransformDataType; -import org.datavec.spark.inference.model.model.*; -import org.junit.AfterClass; -import org.junit.BeforeClass; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.common.io.ClassPathResource; -import org.nd4j.serde.base64.Nd4jBase64; - -import java.io.File; -import java.io.IOException; -import java.util.UUID; - -import static org.junit.Assert.assertEquals; - -public class SparkTransformServerTest { - private static SparkTransformServerChooser serverChooser; - private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); - private static TransformProcess transformProcess = - new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble( "2.0").build(); - - private static File imageTransformFile = new File(UUID.randomUUID().toString() + ".json"); - private static File csvTransformFile = new File(UUID.randomUUID().toString() + ".json"); - - @BeforeClass - public static void before() throws Exception { - serverChooser = new SparkTransformServerChooser(); - - ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(12345) - .scaleImageTransform(10).cropImageTransform(5).build(); - - FileUtils.write(imageTransformFile, imgTransformProcess.toJson()); - - FileUtils.write(csvTransformFile, transformProcess.toJson()); - - Unirest.setObjectMapper(new ObjectMapper() { - private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = - new org.nd4j.shade.jackson.databind.ObjectMapper(); - - public T readValue(String value, Class valueType) { - try { - return jacksonObjectMapper.readValue(value, valueType); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - public String writeValue(Object value) { - try { 
- return jacksonObjectMapper.writeValueAsString(value); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - }); - - - } - - @AfterClass - public static void after() throws Exception { - imageTransformFile.deleteOnExit(); - csvTransformFile.deleteOnExit(); - } - - @Test - public void testImageServer() throws Exception { - serverChooser.runMain(new String[] {"--jsonPath", imageTransformFile.getAbsolutePath(), "-dp", "9060", "-dt", - TransformDataType.IMAGE.toString()}); - - SingleImageRecord record = - new SingleImageRecord(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI()); - JsonNode jsonNode = Unirest.post("http://localhost:9060/transformincrementalarray") - .header("accept", "application/json").header("Content-Type", "application/json").body(record) - .asJson().getBody(); - Base64NDArrayBody array = Unirest.post("http://localhost:9060/transformincrementalarray") - .header("accept", "application/json").header("Content-Type", "application/json").body(record) - .asObject(Base64NDArrayBody.class).getBody(); - - BatchImageRecord batch = new BatchImageRecord(); - batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI()); - batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile().toURI()); - batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile().toURI()); - - JsonNode jsonNodeBatch = - Unirest.post("http://localhost:9060/transformarray").header("accept", "application/json") - .header("Content-Type", "application/json").body(batch).asJson().getBody(); - Base64NDArrayBody batchArray = Unirest.post("http://localhost:9060/transformarray") - .header("accept", "application/json").header("Content-Type", "application/json").body(batch) - .asObject(Base64NDArrayBody.class).getBody(); - - INDArray result = getNDArray(jsonNode); - assertEquals(1, result.size(0)); - - INDArray batchResult = getNDArray(jsonNodeBatch); - assertEquals(3, batchResult.size(0)); - - serverChooser.getSparkTransformServer().stop(); - } - - @Test - public void testCSVServer() throws Exception { - serverChooser.runMain(new String[] {"--jsonPath", csvTransformFile.getAbsolutePath(), "-dp", "9050", "-dt", - TransformDataType.CSV.toString()}); - - String[] values = new String[] {"1.0", "2.0"}; - SingleCSVRecord record = new SingleCSVRecord(values); - JsonNode jsonNode = - Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json") - .header("Content-Type", "application/json").body(record).asJson().getBody(); - SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental") - .header("accept", "application/json").header("Content-Type", "application/json").body(record) - .asObject(SingleCSVRecord.class).getBody(); - - BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); - for (int i = 0; i < 3; i++) - batchCSVRecord.add(singleCsvRecord); - BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform") - .header("accept", "application/json").header("Content-Type", "application/json") - .body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody(); - - Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray") - .header("accept", "application/json").header("Content-Type", "application/json").body(record) - .asObject(Base64NDArrayBody.class).getBody(); - - Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray") - 
.header("accept", "application/json").header("Content-Type", "application/json") - .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody(); - - - serverChooser.getSparkTransformServer().stop(); - } - - public INDArray getNDArray(JsonNode node) throws IOException { - return Nd4jBase64.fromBase64(node.getObject().getString("ndarray")); - } -} diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/resources/application.conf b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/resources/application.conf deleted file mode 100644 index dbac92d83..000000000 --- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/resources/application.conf +++ /dev/null @@ -1,6 +0,0 @@ -play.modules.enabled += com.lightbend.lagom.discovery.zookeeper.ZooKeeperServiceLocatorModule -play.modules.enabled += io.skymind.skil.service.PredictionModule -play.crypto.secret = as8dufasdfuasdfjkasdkfalksjfk -play.server.pidfile.path=/tmp/RUNNING_PID - -play.server.http.port = 9600 diff --git a/datavec/datavec-spark-inference-parent/pom.xml b/datavec/datavec-spark-inference-parent/pom.xml deleted file mode 100644 index abf3f3b0d..000000000 --- a/datavec/datavec-spark-inference-parent/pom.xml +++ /dev/null @@ -1,68 +0,0 @@ - - - - - - 4.0.0 - - - org.datavec - datavec-parent - 1.0.0-SNAPSHOT - - - datavec-spark-inference-parent - pom - - datavec-spark-inference-parent - - - datavec-spark-inference-server - datavec-spark-inference-client - datavec-spark-inference-model - - - - - - org.datavec - datavec-data-image - ${datavec.version} - - - com.mashape.unirest - unirest-java - ${unirest.version} - - - - - - - test-nd4j-native - - - test-nd4j-cuda-11.0 - - - diff --git a/datavec/pom.xml b/datavec/pom.xml index 4142db170..d1c46077f 100644 --- a/datavec/pom.xml +++ b/datavec/pom.xml @@ -45,7 +45,6 @@ datavec-data datavec-spark datavec-local - datavec-spark-inference-parent datavec-jdbc datavec-excel datavec-arrow diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml deleted file mode 100644 index ee029d09f..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml +++ /dev/null @@ -1,143 +0,0 @@ - - - - - - 4.0.0 - - - org.deeplearning4j - deeplearning4j-nearestneighbors-parent - 1.0.0-SNAPSHOT - - - deeplearning4j-nearestneighbor-server - jar - - deeplearning4j-nearestneighbor-server - - - 1.8 - - - - - org.deeplearning4j - deeplearning4j-nearestneighbors-model - ${project.version} - - - org.deeplearning4j - deeplearning4j-core - ${project.version} - - - io.vertx - vertx-core - ${vertx.version} - - - io.vertx - vertx-web - ${vertx.version} - - - com.mashape.unirest - unirest-java - ${unirest.version} - test - - - org.deeplearning4j - deeplearning4j-nearestneighbors-client - ${project.version} - test - - - com.beust - jcommander - ${jcommander.version} - - - ch.qos.logback - logback-classic - test - - - org.deeplearning4j - deeplearning4j-common-tests - ${project.version} - test - - - - - - - org.apache.maven.plugins - maven-surefire-plugin - - -Dfile.encoding=UTF-8 -Xmx8g - - - *.java - **/*.java - - - - - org.apache.maven.plugins - maven-compiler-plugin - - ${java.compile.version} - ${java.compile.version} - - - - - - - - test-nd4j-native - - - org.nd4j - nd4j-native - ${project.version} - test - - - - - 
test-nd4j-cuda-11.0 - - - org.nd4j - nd4j-cuda-11.0 - ${project.version} - test - - - - - diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighbor.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighbor.java deleted file mode 100644 index 88f3a7b46..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighbor.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.nearestneighbor.server; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import org.deeplearning4j.clustering.sptree.DataPoint; -import org.deeplearning4j.clustering.vptree.VPTree; -import org.deeplearning4j.nearestneighbor.model.NearestNeighborRequest; -import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult; -import org.nd4j.linalg.api.ndarray.INDArray; - -import java.util.ArrayList; -import java.util.List; - -@AllArgsConstructor -@Builder -public class NearestNeighbor { - private NearestNeighborRequest record; - private VPTree tree; - private INDArray points; - - public List search() { - INDArray input = points.slice(record.getInputIndex()); - List results = new ArrayList<>(); - if (input.isVector()) { - List add = new ArrayList<>(); - List distances = new ArrayList<>(); - tree.search(input, record.getK(), add, distances); - - if (add.size() != distances.size()) { - throw new IllegalStateException( - String.format("add.size == %d != %d == distances.size", - add.size(), distances.size())); - } - - for (int i=0; i print the usage info - jcmdr.usage(); - if (r.ndarrayPath == null) - log.error("Json path parameter is missing (null)"); - try { - Thread.sleep(500); - } catch (Exception e2) { - } - System.exit(1); - } - - instanceArgs = r; - try { - Vertx vertx = Vertx.vertx(); - vertx.deployVerticle(NearestNeighborsServer.class.getName()); - } catch (Throwable t){ - log.error("Error in NearestNeighboursServer run method",t); - } - } - - @Override - public void start() throws Exception { - instance = this; - - String[] pathArr = instanceArgs.ndarrayPath.split(","); - //INDArray[] pointsArr = new INDArray[pathArr.length]; - // first of all we reading shapes of saved eariler files - int rows = 0; - int cols = 0; - for (int i = 0; i < pathArr.length; i++) { - DataBuffer shape = BinarySerde.readShapeFromDisk(new File(pathArr[i])); - - log.info("Loading shape {} of {}; Shape: [{} 
x {}]", i + 1, pathArr.length, Shape.size(shape, 0), - Shape.size(shape, 1)); - - if (Shape.rank(shape) != 2) - throw new DL4JInvalidInputException("NearestNeighborsServer assumes 2D chunks"); - - rows += Shape.size(shape, 0); - - if (cols == 0) - cols = Shape.size(shape, 1); - else if (cols != Shape.size(shape, 1)) - throw new DL4JInvalidInputException( - "NearestNeighborsServer requires equal 2D chunks. Got columns mismatch."); - } - - final List labels = new ArrayList<>(); - if (instanceArgs.labelsPath != null) { - String[] labelsPathArr = instanceArgs.labelsPath.split(","); - for (int i = 0; i < labelsPathArr.length; i++) { - labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8")); - } - } - if (!labels.isEmpty() && labels.size() != rows) - throw new DL4JInvalidInputException(String.format("Number of labels must match number of rows in points matrix (expected %d, found %d)", rows, labels.size())); - - final INDArray points = Nd4j.createUninitialized(rows, cols); - - int lastPosition = 0; - for (int i = 0; i < pathArr.length; i++) { - log.info("Loading chunk {} of {}", i + 1, pathArr.length); - INDArray pointsArr = BinarySerde.readFromDisk(new File(pathArr[i])); - - points.get(NDArrayIndex.interval(lastPosition, lastPosition + pointsArr.rows())).assign(pointsArr); - lastPosition += pointsArr.rows(); - - // let's ensure we don't bring too much stuff in next loop - System.gc(); - } - - VPTree tree = new VPTree(points, instanceArgs.similarityFunction, instanceArgs.invert); - - //Set play secret key, if required - //http://www.playframework.com/documentation/latest/ApplicationSecret - String crypto = System.getProperty("play.crypto.secret"); - if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) { - byte[] newCrypto = new byte[1024]; - - new Random().nextBytes(newCrypto); - - String base64 = Base64.getEncoder().encodeToString(newCrypto); - System.setProperty("play.crypto.secret", base64); - } - - Router r = Router.router(vertx); - r.route().handler(BodyHandler.create()); //NOTE: Setting this is required to receive request body content at all - createRoutes(r, labels, tree, points); - - vertx.createHttpServer() - .requestHandler(r) - .listen(instanceArgs.port); - } - - private void createRoutes(Router r, List labels, VPTree tree, INDArray points){ - - r.post("/knn").handler(rc -> { - try { - String json = rc.getBodyAsJson().encode(); - NearestNeighborRequest record = JsonMappers.getMapper().readValue(json, NearestNeighborRequest.class); - - NearestNeighbor nearestNeighbor = - NearestNeighbor.builder().points(points).record(record).tree(tree).build(); - - if (record == null) { - rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()) - .putHeader("content-type", "application/json") - .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed."))); - return; - } - - NearestNeighborsResults results = NearestNeighborsResults.builder().results(nearestNeighbor.search()).build(); - - rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()) - .putHeader("content-type", "application/json") - .end(JsonMappers.getMapper().writeValueAsString(results)); - return; - } catch (Throwable e) { - log.error("Error in POST /knn",e); - rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) - .end("Error parsing request - " + e.getMessage()); - return; - } - }); - - r.post("/knnnew").handler(rc -> { - try { - String json = rc.getBodyAsJson().encode(); - Base64NDArrayBody record = 
JsonMappers.getMapper().readValue(json, Base64NDArrayBody.class); - if (record == null) { - rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()) - .putHeader("content-type", "application/json") - .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed."))); - return; - } - - INDArray arr = Nd4jBase64.fromBase64(record.getNdarray()); - List results; - List distances; - - if (record.isForceFillK()) { - VPTreeFillSearch vpTreeFillSearch = new VPTreeFillSearch(tree, record.getK(), arr); - vpTreeFillSearch.search(); - results = vpTreeFillSearch.getResults(); - distances = vpTreeFillSearch.getDistances(); - } else { - results = new ArrayList<>(); - distances = new ArrayList<>(); - tree.search(arr, record.getK(), results, distances); - } - - if (results.size() != distances.size()) { - rc.response() - .setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) - .end(String.format("results.size == %d != %d == distances.size", results.size(), distances.size())); - return; - } - - List nnResult = new ArrayList<>(); - for (int i=0; i results = nearestNeighbor.search(); - assertEquals(1, results.get(0).getIndex()); - assertEquals(2, results.size()); - - assertEquals(1.0, results.get(0).getDistance(), 1e-4); - assertEquals(4.0, results.get(1).getDistance(), 1e-4); - } - - @Test - public void testNearestNeighborInverted() { - double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}}; - INDArray arr = Nd4j.create(data); - - VPTree vpTree = new VPTree(arr, true); - NearestNeighborRequest request = new NearestNeighborRequest(); - request.setK(2); - request.setInputIndex(0); - NearestNeighbor nearestNeighbor = NearestNeighbor.builder().tree(vpTree).points(arr).record(request).build(); - List results = nearestNeighbor.search(); - assertEquals(2, results.get(0).getIndex()); - assertEquals(2, results.size()); - - assertEquals(-4.0, results.get(0).getDistance(), 1e-4); - assertEquals(-1.0, results.get(1).getDistance(), 1e-4); - } - - @Test - public void vpTreeTest() throws Exception { - INDArray matrix = Nd4j.rand(new int[] {400,10}); - INDArray rowVector = matrix.getRow(70); - INDArray resultArr = Nd4j.zeros(400,1); - Executor executor = Executors.newSingleThreadExecutor(); - VPTree vpTree = new VPTree(matrix); - System.out.println("Ran!"); - } - - - - public static int getAvailablePort() { - try { - ServerSocket socket = new ServerSocket(0); - try { - return socket.getLocalPort(); - } finally { - socket.close(); - } - } catch (IOException e) { - throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e); - } - } - - @Test - public void testServer() throws Exception { - int localPort = getAvailablePort(); - Nd4j.getRandom().setSeed(7); - INDArray rand = Nd4j.randn(10, 5); - File writeToTmp = testDir.newFile(); - writeToTmp.deleteOnExit(); - BinarySerde.writeArrayToDisk(rand, writeToTmp); - NearestNeighborsServer.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort", - String.valueOf(localPort)); - - Thread.sleep(3000); - - NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort); - NearestNeighborsResults result = client.knnNew(5, rand.getRow(0)); - assertEquals(5, result.getResults().size()); - NearestNeighborsServer.getInstance().stop(); - } - - - - @Test - public void testFullSearch() throws Exception { - int numRows = 1000; - int numCols = 100; - int numNeighbors = 42; - INDArray points = Nd4j.rand(numRows, numCols); - VPTree tree = new 
VPTree(points); - INDArray query = Nd4j.rand(new int[] {1, numCols}); - VPTreeFillSearch fillSearch = new VPTreeFillSearch(tree, numNeighbors, query); - fillSearch.search(); - List results = fillSearch.getResults(); - List distances = fillSearch.getDistances(); - assertEquals(numNeighbors, distances.size()); - assertEquals(numNeighbors, results.size()); - } - - @Test - public void testDistances() { - - INDArray indArray = Nd4j.create(new float[][]{{3, 4}, {1, 2}, {5, 6}}); - INDArray record = Nd4j.create(new float[][]{{7, 6}}); - VPTree vpTree = new VPTree(indArray, "euclidean", false); - VPTreeFillSearch vpTreeFillSearch = new VPTreeFillSearch(vpTree, 3, record); - vpTreeFillSearch.search(); - //System.out.println(vpTreeFillSearch.getResults()); - System.out.println(vpTreeFillSearch.getDistances()); - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml deleted file mode 100644 index 7e0af0fa1..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml +++ /dev/null @@ -1,46 +0,0 @@ - - - - - - logs/application.log - - %logger{15} - %message%n%xException{5} - - - - - - - %logger{15} - %message%n%xException{5} - - - - - - - - - - - - - \ No newline at end of file diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml deleted file mode 100644 index 55d7b83f9..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml +++ /dev/null @@ -1,60 +0,0 @@ - - - - - - 4.0.0 - - - org.deeplearning4j - deeplearning4j-nearestneighbors-parent - 1.0.0-SNAPSHOT - - - deeplearning4j-nearestneighbors-client - jar - - deeplearning4j-nearestneighbors-client - - - - com.mashape.unirest - unirest-java - ${unirest.version} - - - org.deeplearning4j - deeplearning4j-nearestneighbors-model - ${project.version} - - - - - - test-nd4j-native - - - test-nd4j-cuda-11.0 - - - diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java deleted file mode 100644 index 570e75bf9..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java +++ /dev/null @@ -1,137 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. 
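The VPTree tests above exercise construction, plain search, and fill search separately; a condensed, self-contained sketch of the same workflow (matrix sizes and k are arbitrary, and the import paths follow this module's packages):

    import org.deeplearning4j.clustering.sptree.DataPoint;
    import org.deeplearning4j.clustering.vptree.VPTree;
    import org.deeplearning4j.clustering.vptree.VPTreeFillSearch;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import java.util.ArrayList;
    import java.util.List;

    public class VpTreeSketch {
        public static void main(String[] args) {
            INDArray points = Nd4j.rand(1000, 100);
            VPTree tree = new VPTree(points, "euclidean", false); // invert = false

            // Plain k-NN: result and distance lists are filled in place
            List<DataPoint> results = new ArrayList<>();
            List<Double> distances = new ArrayList<>();
            tree.search(Nd4j.rand(1, 100), 5, results, distances);

            // Fill-search variant, used above when exactly k results are wanted
            VPTreeFillSearch fill = new VPTreeFillSearch(tree, 5, Nd4j.rand(1, 100));
            fill.search();
            System.out.println(fill.getResults().size() + " neighbours found");
        }
    }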
- * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.nearestneighbor.client; - -import com.mashape.unirest.http.ObjectMapper; -import com.mashape.unirest.http.Unirest; -import com.mashape.unirest.request.HttpRequest; -import lombok.AllArgsConstructor; -import lombok.Getter; -import lombok.Setter; -import lombok.val; -import org.deeplearning4j.nearestneighbor.model.*; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.serde.base64.Nd4jBase64; -import org.nd4j.shade.jackson.core.JsonProcessingException; - -import java.io.IOException; - -@AllArgsConstructor -public class NearestNeighborsClient { - - private String url; - @Setter - @Getter - protected String authToken; - - public NearestNeighborsClient(String url){ - this(url, null); - } - - static { - // Only one time - - Unirest.setObjectMapper(new ObjectMapper() { - private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = - new org.nd4j.shade.jackson.databind.ObjectMapper(); - - public T readValue(String value, Class valueType) { - try { - return jacksonObjectMapper.readValue(value, valueType); - } catch (IOException e) { - throw new RuntimeException(e); - } - } - - public String writeValue(Object value) { - try { - return jacksonObjectMapper.writeValueAsString(value); - } catch (JsonProcessingException e) { - throw new RuntimeException(e); - } - } - }); - } - - - /** - * Runs knn on the given index - * with the given k (note that this is for data - * already within the existing dataset not new data) - * @param index the index of the - * EXISTING ndarray - * to run a search on - * @param k the number of results - * @return - * @throws Exception - */ - public NearestNeighborsResults knn(int index, int k) throws Exception { - NearestNeighborRequest request = new NearestNeighborRequest(); - request.setInputIndex(index); - request.setK(k); - val req = Unirest.post(url + "/knn"); - req.header("accept", "application/json") - .header("Content-Type", "application/json").body(request); - addAuthHeader(req); - - NearestNeighborsResults ret = req.asObject(NearestNeighborsResults.class).getBody(); - return ret; - } - - /** - * Run a k nearest neighbors search - * on a NEW data point - * @param k the number of results - * to retrieve - * @param arr the array to run the search on. 
- * Note that this must be a row vector - * @return - * @throws Exception - */ - public NearestNeighborsResults knnNew(int k, INDArray arr) throws Exception { - Base64NDArrayBody base64NDArrayBody = - Base64NDArrayBody.builder().k(k).ndarray(Nd4jBase64.base64String(arr)).build(); - - val req = Unirest.post(url + "/knnnew"); - req.header("accept", "application/json") - .header("Content-Type", "application/json").body(base64NDArrayBody); - addAuthHeader(req); - - NearestNeighborsResults ret = req.asObject(NearestNeighborsResults.class).getBody(); - - return ret; - } - - - /** - * Add the specified authentication header to the specified HttpRequest - * - * @param request HTTP Request to add the authentication header to - */ - protected HttpRequest addAuthHeader(HttpRequest request) { - if (authToken != null) { - request.header("authorization", "Bearer " + authToken); - } - - return request; - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml deleted file mode 100644 index 09a72628e..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml +++ /dev/null @@ -1,61 +0,0 @@ - - - - - - 4.0.0 - - - org.deeplearning4j - deeplearning4j-nearestneighbors-parent - 1.0.0-SNAPSHOT - - - deeplearning4j-nearestneighbors-model - jar - - deeplearning4j-nearestneighbors-model - - - - org.projectlombok - lombok - ${lombok.version} - provided - - - org.nd4j - nd4j-api - ${nd4j.version} - - - - - - test-nd4j-native - - - test-nd4j-cuda-11.0 - - - diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/Base64NDArrayBody.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/Base64NDArrayBody.java deleted file mode 100644 index c68f48ebe..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/Base64NDArrayBody.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
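A short usage sketch for the client above; the URL, port, and token are illustrative, and a NearestNeighborsServer must already be listening:

    import org.deeplearning4j.nearestneighbor.client.NearestNeighborsClient;
    import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult;
    import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResults;
    import org.nd4j.linalg.factory.Nd4j;

    public class KnnClientSketch {
        public static void main(String[] args) throws Exception {
            NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:9000");
            client.setAuthToken("my-token"); // optional; sent as a Bearer header

            // k-NN over a point already present in the server-side index
            NearestNeighborsResults existing = client.knn(0, 5);

            // k-NN over a brand-new row vector, shipped base64-encoded
            NearestNeighborsResults fresh = client.knnNew(5, Nd4j.rand(1, 100));
            for (NearestNeighborsResult r : fresh.getResults()) {
                System.out.println(r.getIndex() + " -> " + r.getDistance());
            }
        }
    }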
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.nearestneighbor.model; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.io.Serializable; - -@Data -@AllArgsConstructor -@NoArgsConstructor -@Builder -public class Base64NDArrayBody implements Serializable { - private String ndarray; - private int k; - private boolean forceFillK; -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/BatchRecord.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/BatchRecord.java deleted file mode 100644 index f2a9475a1..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/BatchRecord.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.nearestneighbor.model; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; -import org.nd4j.linalg.dataset.DataSet; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; - -@Data -@AllArgsConstructor -@Builder -@NoArgsConstructor -public class BatchRecord implements Serializable { - private List records; - - /** - * Add a record - * @param record - */ - public void add(CSVRecord record) { - if (records == null) - records = new ArrayList<>(); - records.add(record); - } - - - /** - * Return a batch record based on a dataset - * @param dataSet the dataset to get the batch record for - * @return the batch record - */ - public static BatchRecord fromDataSet(DataSet dataSet) { - BatchRecord batchRecord = new BatchRecord(); - for (int i = 0; i < dataSet.numExamples(); i++) { - batchRecord.add(CSVRecord.fromRow(dataSet.get(i))); - } - - return batchRecord; - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/CSVRecord.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/CSVRecord.java deleted file mode 100644 index ef642bf0d..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/CSVRecord.java +++ /dev/null @@ -1,85 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
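A brief sketch of BatchRecord.fromDataSet from the class above: each example in the DataSet becomes one CSVRecord. The 2x2 toy data is illustrative:

    import org.deeplearning4j.nearestneighbor.model.BatchRecord;
    import org.nd4j.linalg.dataset.DataSet;
    import org.nd4j.linalg.factory.Nd4j;

    public class BatchRecordSketch {
        public static void main(String[] args) {
            DataSet ds = new DataSet(
                    Nd4j.create(new double[][] {{1, 2}, {3, 4}}),   // features
                    Nd4j.create(new double[][] {{0, 1}, {1, 0}}));  // one-hot labels
            BatchRecord batch = BatchRecord.fromDataSet(ds);
            System.out.println(batch.getRecords().size()); // 2
        }
    }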
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.nearestneighbor.model; - -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; -import org.nd4j.linalg.dataset.DataSet; - -import java.io.Serializable; - -@Data -@AllArgsConstructor -@NoArgsConstructor -public class CSVRecord implements Serializable { - private String[] values; - - /** - * Instantiate a csv record from a vector - * given either an input dataset and a - * one hot matrix, the index will be appended to - * the end of the record, or for regression - * it will append all values in the labels - * @param row the input vectors - * @return the record from this {@link DataSet} - */ - public static CSVRecord fromRow(DataSet row) { - if (!row.getFeatures().isVector() && !row.getFeatures().isScalar()) - throw new IllegalArgumentException("Passed in dataset must represent a scalar or vector"); - if (!row.getLabels().isVector() && !row.getLabels().isScalar()) - throw new IllegalArgumentException("Passed in dataset labels must be a scalar or vector"); - //classification - CSVRecord record; - int idx = 0; - if (row.getLabels().sumNumber().doubleValue() == 1.0) { - String[] values = new String[row.getFeatures().columns() + 1]; - for (int i = 0; i < row.getFeatures().length(); i++) { - values[idx++] = String.valueOf(row.getFeatures().getDouble(i)); - } - int maxIdx = 0; - for (int i = 0; i < row.getLabels().length(); i++) { - if (row.getLabels().getDouble(maxIdx) < row.getLabels().getDouble(i)) { - maxIdx = i; - } - } - - values[idx++] = String.valueOf(maxIdx); - record = new CSVRecord(values); - } - //regression (any number of values) - else { - String[] values = new String[row.getFeatures().columns() + row.getLabels().columns()]; - for (int i = 0; i < row.getFeatures().length(); i++) { - values[idx++] = String.valueOf(row.getFeatures().getDouble(i)); - } - for (int i = 0; i < row.getLabels().length(); i++) { - values[idx++] = String.valueOf(row.getLabels().getDouble(i)); - } - - - record = new CSVRecord(values); - - } - return record; - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborRequest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborRequest.java deleted file mode 100644 index 5044c6b35..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborRequest.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
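Since the sumNumber() == 1.0 test above is what separates the two branches, a worked example makes fromRow's contract concrete (toy values, shown as bare statements; they assume the model classes above plus org.nd4j.linalg.dataset.DataSet and org.nd4j.linalg.factory.Nd4j):

    // One-hot labels (sum == 1.0) -> classification branch: argmax is appended
    DataSet classify = new DataSet(
            Nd4j.create(new double[][] {{0.5, 0.25}}),
            Nd4j.create(new double[][] {{0, 1, 0}}));
    // values -> ["0.5", "0.25", "1"]  (features + winning class index)
    System.out.println(java.util.Arrays.toString(CSVRecord.fromRow(classify).getValues()));

    // Any other labels -> regression branch: every label value is appended
    DataSet regress = new DataSet(
            Nd4j.create(new double[][] {{0.5, 0.25}}),
            Nd4j.create(new double[][] {{2.0, 3.5}}));
    // values -> ["0.5", "0.25", "2.0", "3.5"]
    System.out.println(java.util.Arrays.toString(CSVRecord.fromRow(regress).getValues()));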
See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.nearestneighbor.model; - -import lombok.Data; - -import java.io.Serializable; - -@Data -public class NearestNeighborRequest implements Serializable { - private int k; - private int inputIndex; - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborsResult.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborsResult.java deleted file mode 100644 index 768b0dfc9..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborsResult.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.nearestneighbor.model; - -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; -@Data -@AllArgsConstructor -@NoArgsConstructor -public class NearestNeighborsResult { - public NearestNeighborsResult(int index, double distance) { - this(index, distance, null); - } - - private int index; - private double distance; - private String label; -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborsResults.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborsResults.java deleted file mode 100644 index d95c68fb6..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborsResults.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. 
- * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.nearestneighbor.model; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.io.Serializable; -import java.util.List; - -@Data -@Builder -@NoArgsConstructor -@AllArgsConstructor -public class NearestNeighborsResults implements Serializable { - private List results; - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml deleted file mode 100644 index 5df85229d..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml +++ /dev/null @@ -1,103 +0,0 @@ - - - - - - 4.0.0 - - - org.deeplearning4j - deeplearning4j-nearestneighbors-parent - 1.0.0-SNAPSHOT - - - nearestneighbor-core - jar - - nearestneighbor-core - - - - org.nd4j - nd4j-api - ${nd4j.version} - - - junit - junit - - - ch.qos.logback - logback-classic - test - - - org.deeplearning4j - deeplearning4j-nn - ${project.version} - - - org.deeplearning4j - deeplearning4j-datasets - ${project.version} - test - - - joda-time - joda-time - 2.10.3 - test - - - org.deeplearning4j - deeplearning4j-common-tests - ${project.version} - test - - - - - - test-nd4j-native - - - org.nd4j - nd4j-native - ${project.version} - test - - - - - test-nd4j-cuda-11.0 - - - org.nd4j - nd4j-cuda-11.0 - ${project.version} - test - - - - - diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/BaseClusteringAlgorithm.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/BaseClusteringAlgorithm.java deleted file mode 100755 index e7e467ad3..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/BaseClusteringAlgorithm.java +++ /dev/null @@ -1,218 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.algorithm; - -import lombok.AccessLevel; -import lombok.NoArgsConstructor; -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.apache.commons.lang3.ArrayUtils; -import org.deeplearning4j.clustering.cluster.Cluster; -import org.deeplearning4j.clustering.cluster.ClusterSet; -import org.deeplearning4j.clustering.cluster.ClusterUtils; -import org.deeplearning4j.clustering.cluster.Point; -import org.deeplearning4j.clustering.info.ClusterSetInfo; -import org.deeplearning4j.clustering.iteration.IterationHistory; -import org.deeplearning4j.clustering.iteration.IterationInfo; -import org.deeplearning4j.clustering.strategy.ClusteringStrategy; -import org.deeplearning4j.clustering.strategy.ClusteringStrategyType; -import org.deeplearning4j.clustering.strategy.OptimisationStrategy; -import org.deeplearning4j.clustering.util.MultiThreadUtils; -import org.nd4j.common.base.Preconditions; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.ExecutorService; - -@Slf4j -@NoArgsConstructor(access = AccessLevel.PROTECTED) -public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializable { - - private static final long serialVersionUID = 338231277453149972L; - - private ClusteringStrategy clusteringStrategy; - private IterationHistory iterationHistory; - private int currentIteration = 0; - private ClusterSet clusterSet; - private List initialPoints; - private transient ExecutorService exec; - private boolean useKmeansPlusPlus; - - - protected BaseClusteringAlgorithm(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) { - this.clusteringStrategy = clusteringStrategy; - this.exec = MultiThreadUtils.newExecutorService(); - this.useKmeansPlusPlus = useKmeansPlusPlus; - } - - /** - * - * @param clusteringStrategy - * @return - */ - public static BaseClusteringAlgorithm setup(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) { - return new BaseClusteringAlgorithm(clusteringStrategy, useKmeansPlusPlus); - } - - /** - * - * @param points - * @return - */ - public ClusterSet applyTo(List points) { - resetState(points); - initClusters(useKmeansPlusPlus); - iterations(); - return clusterSet; - } - - private void resetState(List points) { - this.iterationHistory = new IterationHistory(); - this.currentIteration = 0; - this.clusterSet = null; - this.initialPoints = points; - } - - /** Run clustering iterations until a - * termination condition is hit. - * This is done by first classifying all points, - * and then updating cluster centers based on - * those classified points - */ - private void iterations() { - int iterationCount = 0; - while ((clusteringStrategy.getTerminationCondition() != null - && !clusteringStrategy.getTerminationCondition().isSatisfied(iterationHistory)) - || iterationHistory.getMostRecentIterationInfo().isStrategyApplied()) { - currentIteration++; - removePoints(); - classifyPoints(); - applyClusteringStrategy(); - log.trace("Completed clustering iteration {}", ++iterationCount); - } - } - - protected void classifyPoints() { - //Classify points. 
This also adds each point to the ClusterSet - ClusterSetInfo clusterSetInfo = ClusterUtils.classifyPoints(clusterSet, initialPoints, exec); - //Update the cluster centers, based on the points within each cluster - ClusterUtils.refreshClustersCenters(clusterSet, clusterSetInfo, exec); - iterationHistory.getIterationsInfos().put(currentIteration, - new IterationInfo(currentIteration, clusterSetInfo)); - } - - /** - * Initialize the - * cluster centers at random - */ - protected void initClusters(boolean kMeansPlusPlus) { - log.info("Generating initial clusters"); - List points = new ArrayList<>(initialPoints); - - //Initialize the ClusterSet with a single cluster center (based on position of one of the points chosen randomly) - val random = Nd4j.getRandom(); - Distance distanceFn = clusteringStrategy.getDistanceFunction(); - int initialClusterCount = clusteringStrategy.getInitialClusterCount(); - clusterSet = new ClusterSet(distanceFn, - clusteringStrategy.inverseDistanceCalculation(), new long[]{initialClusterCount, points.get(0).getArray().length()}); - clusterSet.addNewClusterWithCenter(points.remove(random.nextInt(points.size()))); - - - //dxs: distances between - // each point and nearest cluster to that point - INDArray dxs = Nd4j.create(points.size()); - dxs.addi(clusteringStrategy.inverseDistanceCalculation() ? -Double.MAX_VALUE : Double.MAX_VALUE); - - //Generate the initial cluster centers, by randomly selecting a point between 0 and max distance - //Thus, we are more likely to select (as a new cluster center) a point that is far from an existing cluster - while (clusterSet.getClusterCount() < initialClusterCount && !points.isEmpty()) { - dxs = ClusterUtils.computeSquareDistancesFromNearestCluster(clusterSet, points, dxs, exec); - double summed = Nd4j.sum(dxs).getDouble(0); - double r = kMeansPlusPlus ? random.nextDouble() * summed: - random.nextFloat() * dxs.maxNumber().doubleValue(); - - for (int i = 0; i < dxs.length(); i++) { - double distance = dxs.getDouble(i); - Preconditions.checkState(distance >= 0, "Encountered negative distance: distance function is not valid? 
Distance " + - "function must return values >= 0, got distance %s for function s", distance, distanceFn); - if (dxs.getDouble(i) >= r) { - clusterSet.addNewClusterWithCenter(points.remove(i)); - dxs = Nd4j.create(ArrayUtils.remove(dxs.data().asDouble(), i)); - break; - } - } - } - - ClusterSetInfo initialClusterSetInfo = ClusterUtils.computeClusterSetInfo(clusterSet); - iterationHistory.getIterationsInfos().put(currentIteration, - new IterationInfo(currentIteration, initialClusterSetInfo)); - } - - - protected void applyClusteringStrategy() { - if (!isStrategyApplicableNow()) - return; - - ClusterSetInfo clusterSetInfo = iterationHistory.getMostRecentClusterSetInfo(); - if (!clusteringStrategy.isAllowEmptyClusters()) { - int removedCount = removeEmptyClusters(clusterSetInfo); - if (removedCount > 0) { - iterationHistory.getMostRecentIterationInfo().setStrategyApplied(true); - - if (clusteringStrategy.isStrategyOfType(ClusteringStrategyType.FIXED_CLUSTER_COUNT) - && clusterSet.getClusterCount() < clusteringStrategy.getInitialClusterCount()) { - int splitCount = ClusterUtils.splitMostSpreadOutClusters(clusterSet, clusterSetInfo, - clusteringStrategy.getInitialClusterCount() - clusterSet.getClusterCount(), exec); - if (splitCount > 0) - iterationHistory.getMostRecentIterationInfo().setStrategyApplied(true); - } - } - } - if (clusteringStrategy.isStrategyOfType(ClusteringStrategyType.OPTIMIZATION)) - optimize(); - } - - protected void optimize() { - ClusterSetInfo clusterSetInfo = iterationHistory.getMostRecentClusterSetInfo(); - OptimisationStrategy optimization = (OptimisationStrategy) clusteringStrategy; - boolean applied = ClusterUtils.applyOptimization(optimization, clusterSet, clusterSetInfo, exec); - iterationHistory.getMostRecentIterationInfo().setStrategyApplied(applied); - } - - private boolean isStrategyApplicableNow() { - return clusteringStrategy.isOptimizationDefined() && iterationHistory.getIterationCount() != 0 - && clusteringStrategy.isOptimizationApplicableNow(iterationHistory); - } - - protected int removeEmptyClusters(ClusterSetInfo clusterSetInfo) { - List removedClusters = clusterSet.removeEmptyClusters(); - clusterSetInfo.removeClusterInfos(removedClusters); - return removedClusters.size(); - } - - protected void removePoints() { - clusterSet.removePoints(); - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/ClusteringAlgorithm.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/ClusteringAlgorithm.java deleted file mode 100644 index 02ac17f39..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/ClusteringAlgorithm.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.algorithm; - -import org.deeplearning4j.clustering.cluster.ClusterSet; -import org.deeplearning4j.clustering.cluster.Point; - -import java.util.List; - -public interface ClusteringAlgorithm { - - /** - * Apply the clustering algorithm - * to a given set of points - * @param points the points to cluster - * @return the resulting cluster set - */ - ClusterSet applyTo(List points); - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/Distance.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/Distance.java deleted file mode 100644 index 657df3dfa..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/Distance.java +++ /dev/null @@ -1,41 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.algorithm; - -public enum Distance { - EUCLIDEAN("euclidean"), - COSINE_DISTANCE("cosinedistance"), - COSINE_SIMILARITY("cosinesimilarity"), - MANHATTAN("manhattan"), - DOT("dot"), - JACCARD("jaccard"), - HAMMING("hamming"); - - private String functionName; - private Distance(String name) { - functionName = name; - } - - @Override - public String toString() { - return functionName; - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/CentersHolder.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/CentersHolder.java deleted file mode 100644 index 8a39d8bc3..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/CentersHolder.java +++ /dev/null @@ -1,105 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.cluster; - -import org.deeplearning4j.clustering.algorithm.Distance; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.ReduceOp; -import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.Pair; - -public class CentersHolder { - private INDArray centers; - private long index = 0; - - protected transient ReduceOp op; - protected ArgMin imin; - protected transient INDArray distances; - protected transient INDArray argMin; - - private long rows, cols; - - public CentersHolder(long rows, long cols) { - this.rows = rows; - this.cols = cols; - } - - public INDArray getCenters() { - return this.centers; - } - - public synchronized void addCenter(INDArray pointView) { - if (centers == null) - this.centers = Nd4j.create(pointView.dataType(), new long[] {rows, cols}); - - centers.putRow(index++, pointView); - } - - public synchronized Pair getCenterByMinDistance(Point point, Distance distanceFunction) { - if (distances == null) - distances = Nd4j.create(centers.dataType(), centers.rows()); - - if (argMin == null) - argMin = Nd4j.createUninitialized(DataType.LONG, new long[0]); - - if (op == null) { - op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1); - imin = new ArgMin(distances, argMin); - op.setZ(distances); - } - - op.setY(point.getArray()); - - Nd4j.getExecutioner().exec(op); - Nd4j.getExecutioner().exec(imin); - - Pair result = new Pair<>(); - result.setFirst(distances.getDouble(argMin.getLong(0))); - result.setSecond(argMin.getLong(0)); - return result; - } - - public synchronized INDArray getMinDistances(Point point, Distance distanceFunction) { - if (distances == null) - distances = Nd4j.create(centers.dataType(), centers.rows()); - - if (argMin == null) - argMin = Nd4j.createUninitialized(DataType.LONG, new long[0]); - - if (op == null) { - op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1); - imin = new ArgMin(distances, argMin); - op.setZ(distances); - } - - op.setY(point.getArray()); - - Nd4j.getExecutioner().exec(op); - Nd4j.getExecutioner().exec(imin); - - return distances; - } - - -}
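CentersHolder lazily allocates the distance op and its output buffers once, then reuses them for every nearest-center query. A hedged usage sketch follows; the centers and query vector are assumptions for the example:

    // Hypothetical: three 2-D centers, then one nearest-center lookup.
    CentersHolder holder = new CentersHolder(3, 2);
    holder.addCenter(Nd4j.create(new double[] {0, 0}));
    holder.addCenter(Nd4j.create(new double[] {5, 5}));
    holder.addCenter(Nd4j.create(new double[] {10, 0}));
    Pair<Double, Long> nearest = holder.getCenterByMinDistance(
            new Point(Nd4j.create(new double[] {4, 6})), Distance.EUCLIDEAN);
    // nearest.getFirst() is the distance to the winning center,
    // nearest.getSecond() is that center's row index (1 here, the center at {5, 5}).

diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/Cluster.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/Cluster.java deleted file mode 100644 index 7f4f221e5..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/Cluster.java +++ /dev/null @@ -1,150 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0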
which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.cluster; - -import lombok.Data; -import org.deeplearning4j.clustering.algorithm.Distance; -import org.nd4j.linalg.factory.Nd4j; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.UUID; - -@Data -public class Cluster implements Serializable { - - private String id = UUID.randomUUID().toString(); - private String label; - - private Point center; - private List points = Collections.synchronizedList(new ArrayList()); - private boolean inverse = false; - private Distance distanceFunction; - - public Cluster() { - super(); - } - - /** - * - * @param center - * @param distanceFunction - */ - public Cluster(Point center, Distance distanceFunction) { - this(center, false, distanceFunction); - } - - /** - * - * @param center - * @param distanceFunction - */ - public Cluster(Point center, boolean inverse, Distance distanceFunction) { - this.distanceFunction = distanceFunction; - this.inverse = inverse; - setCenter(center); - } - - /** - * Get the distance to the given - * point from the cluster - * @param point the point to get the distance for - * @return - */ - public double getDistanceToCenter(Point point) { - return Nd4j.getExecutioner().execAndReturn( - ClusterUtils.createDistanceFunctionOp(distanceFunction, center.getArray(), point.getArray())) - .getFinalResult().doubleValue(); - } - - /** - * Add a point to the cluster - * @param point - */ - public void addPoint(Point point) { - addPoint(point, true); - } - - /** - * Add a point to the cluster - * @param point the point to add - * @param moveClusterCenter whether to update - * the cluster centroid or not - */ - public void addPoint(Point point, boolean moveClusterCenter) { - if (moveClusterCenter) { - if (isInverse()) { - center.getArray().muli(points.size()).subi(point.getArray()).divi(points.size() + 1); - } else { - center.getArray().muli(points.size()).addi(point.getArray()).divi(points.size() + 1); - } - } - - getPoints().add(point); - } - - /** - * Clear out the points - */ - public void removePoints() { - if (getPoints() != null) - getPoints().clear(); - } - - /** - * Whether the cluster is empty or not - * @return - */ - public boolean isEmpty() { - return points == null || points.isEmpty(); - } - - /** - * Return the point with the given id - * @param id - * @return - */ - public Point getPoint(String id) { - for (Point point : points) - if (id.equals(point.getId())) - return point; - return null; - } - - /** - * Remove the point and return it - * @param id - * @return - */ - public Point removePoint(String id) { - Point removePoint = null; - for (Point point : points) - if (id.equals(point.getId())) - removePoint = point; - if (removePoint != null) - points.remove(removePoint); - return removePoint; - } - - -}
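Note the centroid update in addPoint(): multiplying the current center by the old point count, adding the new point (or subtracting it, for inverse/similarity measures), and dividing by the new count is an incremental running mean, so the center stays correct without revisiting stored points. The same update on plain arrays, as a sketch (names are illustrative):

    // newCenter = (oldCenter * n + x) / (n + 1), applied coordinate-wise
    static void updateCenter(double[] center, double[] x, int pointsSoFar) {
        for (int i = 0; i < center.length; i++) {
            center[i] = (center[i] * pointsSoFar + x[i]) / (pointsSoFar + 1);
        }
    }

diff --git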
a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterSet.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterSet.java deleted file mode 100644 index dabfdc7a4..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterSet.java +++ /dev/null @@ -1,259 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.cluster; - -import lombok.Data; -import org.deeplearning4j.clustering.algorithm.Distance; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.Pair; - -import java.io.Serializable; -import java.util.*; - -@Data -public class ClusterSet implements Serializable { - - private Distance distanceFunction; - private List clusters; - private CentersHolder centersHolder; - private Map pointDistribution; - private boolean inverse; - - public ClusterSet(boolean inverse) { - this(null, inverse, null); - } - - public ClusterSet(Distance distanceFunction, boolean inverse, long[] shape) { - this.distanceFunction = distanceFunction; - this.inverse = inverse; - this.clusters = Collections.synchronizedList(new ArrayList()); - this.pointDistribution = Collections.synchronizedMap(new HashMap()); - if (shape != null) - this.centersHolder = new CentersHolder(shape[0], shape[1]); - } - - - public boolean isInverse() { - return inverse; - } - - /** - * - * @param center - * @return - */ - public Cluster addNewClusterWithCenter(Point center) { - Cluster newCluster = new Cluster(center, distanceFunction); - getClusters().add(newCluster); - setPointLocation(center, newCluster); - centersHolder.addCenter(center.getArray()); - return newCluster; - } - - /** - * - * @param point - * @return - */ - public PointClassification classifyPoint(Point point) { - return classifyPoint(point, true); - } - - /** - * - * @param points - */ - public void classifyPoints(List points) { - classifyPoints(points, true); - } - - /** - * - * @param points - * @param moveClusterCenter - */ - public void classifyPoints(List points, boolean moveClusterCenter) { - for (Point point : points) - classifyPoint(point, moveClusterCenter); - } - - /** - * - * @param point - * @param moveClusterCenter - * @return - */ - public PointClassification classifyPoint(Point point, boolean moveClusterCenter) { - Pair nearestCluster = nearestCluster(point); - Cluster newCluster = nearestCluster.getKey(); - boolean locationChange = isPointLocationChange(point, newCluster); - addPointToCluster(point, newCluster, 
moveClusterCenter); - return new PointClassification(nearestCluster.getKey(), nearestCluster.getValue(), locationChange); - } - - private boolean isPointLocationChange(Point point, Cluster newCluster) { - if (!getPointDistribution().containsKey(point.getId())) - return true; - return !getPointDistribution().get(point.getId()).equals(newCluster.getId()); - } - - private void addPointToCluster(Point point, Cluster cluster, boolean moveClusterCenter) { - cluster.addPoint(point, moveClusterCenter); - setPointLocation(point, cluster); - } - - private void setPointLocation(Point point, Cluster cluster) { - pointDistribution.put(point.getId(), cluster.getId()); - } - - - /** - * - * @param point - * @return - */ - public Pair nearestCluster(Point point) { - - /*double minDistance = isInverse() ? Float.MIN_VALUE : Float.MAX_VALUE; - - double currentDistance; - for (Cluster cluster : getClusters()) { - currentDistance = cluster.getDistanceToCenter(point); - if (isInverse()) { - if (currentDistance > minDistance) { - minDistance = currentDistance; - nearestCluster = cluster; - } - } else { - if (currentDistance < minDistance) { - minDistance = currentDistance; - nearestCluster = cluster; - } - } - - }*/ - - Pair nearestCenterData = centersHolder. - getCenterByMinDistance(point, distanceFunction); - Cluster nearestCluster = getClusters().get(nearestCenterData.getSecond().intValue()); - double minDistance = nearestCenterData.getFirst(); - return Pair.of(nearestCluster, minDistance); - } - - /** - * - * @param m1 - * @param m2 - * @return - */ - public double getDistance(Point m1, Point m2) { - return Nd4j.getExecutioner() - .execAndReturn(ClusterUtils.createDistanceFunctionOp(distanceFunction, m1.getArray(), m2.getArray())) - .getFinalResult().doubleValue(); - } - - /** - * - * @param point - * @return - */ - /*public double getDistanceFromNearestCluster(Point point) { - return nearestCluster(point).getValue(); - }*/ - - - /** - * - * @param clusterId - * @return - */ - public String getClusterCenterId(String clusterId) { - Point clusterCenter = getClusterCenter(clusterId); - return clusterCenter == null ? null : clusterCenter.getId(); - } - - /** - * - * @param clusterId - * @return - */ - public Point getClusterCenter(String clusterId) { - Cluster cluster = getCluster(clusterId); - return cluster == null ? null : cluster.getCenter(); - } - - /** - * - * @param id - * @return - */ - public Cluster getCluster(String id) { - for (int i = 0, j = clusters.size(); i < j; i++) - if (id.equals(clusters.get(i).getId())) - return clusters.get(i); - return null; - } - - /** - * - * @return - */ - public int getClusterCount() { - return getClusters() == null ? 
0 : getClusters().size(); - } - - /** - * - */ - public void removePoints() { - for (Cluster cluster : getClusters()) - cluster.removePoints(); - } - - /** - * - * @param count - * @return - */ - public List getMostPopulatedClusters(int count) { - List mostPopulated = new ArrayList<>(clusters); - Collections.sort(mostPopulated, new Comparator() { - public int compare(Cluster o1, Cluster o2) { - return Integer.compare(o2.getPoints().size(), o1.getPoints().size()); - } - }); - return mostPopulated.subList(0, count); - } - - /** - * - * @return - */ - public List removeEmptyClusters() { - List emptyClusters = new ArrayList<>(); - for (Cluster cluster : clusters) - if (cluster.isEmpty()) - emptyClusters.add(cluster); - clusters.removeAll(emptyClusters); - return emptyClusters; - } - -}
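Taken together, a hedged sketch of driving ClusterSet directly; the vectors and shapes are assumptions for the example, and in the deleted code BaseClusteringAlgorithm issues these calls:

    // Hypothetical: a 2-cluster set over 2-D points.
    ClusterSet set = new ClusterSet(Distance.EUCLIDEAN, false, new long[] {2, 2});
    set.addNewClusterWithCenter(new Point(Nd4j.create(new double[] {0, 0})));
    set.addNewClusterWithCenter(new Point(Nd4j.create(new double[] {10, 10})));
    PointClassification pc = set.classifyPoint(new Point(Nd4j.create(new double[] {1, 2})));
    // pc.getCluster() is the nearest cluster, pc.getDistanceFromCenter() the distance to it,
    // and pc.isNewLocation() reports whether the point moved to a different cluster.

diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterUtils.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterUtils.java deleted file mode 100644 index ac1786538..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterUtils.java +++ /dev/null @@ -1,531 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.cluster; - -import lombok.AccessLevel; -import lombok.NoArgsConstructor; -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.apache.commons.lang3.ArrayUtils; -import org.deeplearning4j.clustering.algorithm.Distance; -import org.deeplearning4j.clustering.info.ClusterInfo; -import org.deeplearning4j.clustering.info.ClusterSetInfo; -import org.deeplearning4j.clustering.optimisation.ClusteringOptimizationType; -import org.deeplearning4j.clustering.strategy.OptimisationStrategy; -import org.deeplearning4j.clustering.util.MathUtils; -import org.deeplearning4j.clustering.util.MultiThreadUtils; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.ReduceOp; -import org.nd4j.linalg.api.ops.impl.reduce3.*; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.*; -import java.util.concurrent.ExecutorService; - -@NoArgsConstructor(access = AccessLevel.PRIVATE) -@Slf4j -public class ClusterUtils { - - /** Classify the set of points based on cluster centers.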
This also adds each point to the ClusterSet */ - public static ClusterSetInfo classifyPoints(final ClusterSet clusterSet, List points, - ExecutorService executorService) { - final ClusterSetInfo clusterSetInfo = ClusterSetInfo.initialize(clusterSet, true); - - List tasks = new ArrayList<>(); - for (final Point point : points) { - //tasks.add(new Runnable() { - // public void run() { - try { - PointClassification result = classifyPoint(clusterSet, point); - if (result.isNewLocation()) - clusterSetInfo.getPointLocationChange().incrementAndGet(); - clusterSetInfo.getClusterInfo(result.getCluster().getId()).getPointDistancesFromCenter() - .put(point.getId(), result.getDistanceFromCenter()); - } catch (Throwable t) { - log.warn("Error classifying point", t); - } - // } - } - - //MultiThreadUtils.parallelTasks(tasks, executorService); - return clusterSetInfo; - } - - public static PointClassification classifyPoint(ClusterSet clusterSet, Point point) { - return clusterSet.classifyPoint(point, false); - } - - public static void refreshClustersCenters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo, - ExecutorService executorService) { - List tasks = new ArrayList<>(); - int nClusters = clusterSet.getClusterCount(); - for (int i = 0; i < nClusters; i++) { - final Cluster cluster = clusterSet.getClusters().get(i); - //tasks.add(new Runnable() { - // public void run() { - try { - final ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId()); - refreshClusterCenter(cluster, clusterInfo); - deriveClusterInfoDistanceStatistics(clusterInfo); - } catch (Throwable t) { - log.warn("Error refreshing cluster centers", t); - } - // } - //}); - } - //MultiThreadUtils.parallelTasks(tasks, executorService); - } - - public static void refreshClusterCenter(Cluster cluster, ClusterInfo clusterInfo) { - int pointsCount = cluster.getPoints().size(); - if (pointsCount == 0) - return; - Point center = new Point(Nd4j.create(cluster.getPoints().get(0).getArray().length())); - for (Point point : cluster.getPoints()) { - INDArray arr = point.getArray(); - if (cluster.isInverse()) - center.getArray().subi(arr); - else - center.getArray().addi(arr); - } - center.getArray().divi(pointsCount); - cluster.setCenter(center); - } - - /** - * - * @param info - */ - public static void deriveClusterInfoDistanceStatistics(ClusterInfo info) { - int pointCount = info.getPointDistancesFromCenter().size(); - if (pointCount == 0) - return; - - double[] distances = - ArrayUtils.toPrimitive(info.getPointDistancesFromCenter().values().toArray(new Double[] {})); - double max = info.isInverse() ? 
MathUtils.min(distances) : MathUtils.max(distances); - double total = MathUtils.sum(distances); - info.setMaxPointDistanceFromCenter(max); - info.setTotalPointDistanceFromCenter(total); - info.setAveragePointDistanceFromCenter(total / pointCount); - info.setPointDistanceFromCenterVariance(MathUtils.variance(distances)); - } - - /** - * - * @param clusterSet - * @param points - * @param previousDxs - * @param executorService - * @return - */ - public static INDArray computeSquareDistancesFromNearestCluster(final ClusterSet clusterSet, - final List points, INDArray previousDxs, ExecutorService executorService) { - final int pointsCount = points.size(); - final INDArray dxs = Nd4j.create(pointsCount); - final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1); - - List tasks = new ArrayList<>(); - for (int i = 0; i < pointsCount; i++) { - final int i2 = i; - //tasks.add(new Runnable() { - // public void run() { - try { - Point point = points.get(i2); - double dist = clusterSet.isInverse() ? newCluster.getDistanceToCenter(point) - : Math.pow(newCluster.getDistanceToCenter(point), 2); - dxs.putScalar(i2, /*clusterSet.isInverse() ? dist :*/ dist); - } catch (Throwable t) { - log.warn("Error computing squared distance from nearest cluster", t); - } - // } - //}); - - } - - //MultiThreadUtils.parallelTasks(tasks, executorService); - for (int i = 0; i < pointsCount; i++) { - double previousMinDistance = previousDxs.getDouble(i); - if (clusterSet.isInverse()) { - if (dxs.getDouble(i) < previousMinDistance) { - - dxs.putScalar(i, previousMinDistance); - } - } else if (dxs.getDouble(i) > previousMinDistance) - dxs.putScalar(i, previousMinDistance); - } - - return dxs; - } - - public static INDArray computeWeightedProbaDistancesFromNearestCluster(final ClusterSet clusterSet, - final List points, INDArray previousDxs) { - final int pointsCount = points.size(); - final INDArray dxs = Nd4j.create(pointsCount); - final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1); - - Double sum = new Double(0); - for (int i = 0; i < pointsCount; i++) { - - Point point = points.get(i); - double dist = Math.pow(newCluster.getDistanceToCenter(point), 2); - sum += dist; - dxs.putScalar(i, sum); - } - - return dxs; - } - /** - * - * @param clusterSet - * @return - */ - public static ClusterSetInfo computeClusterSetInfo(ClusterSet clusterSet) { - ExecutorService executor = MultiThreadUtils.newExecutorService(); - ClusterSetInfo info = computeClusterSetInfo(clusterSet, executor); - executor.shutdownNow(); - return info; - } - - public static ClusterSetInfo computeClusterSetInfo(final ClusterSet clusterSet, ExecutorService executorService) { - final ClusterSetInfo info = new ClusterSetInfo(clusterSet.isInverse(), true); - int clusterCount = clusterSet.getClusterCount(); - - List tasks = new ArrayList<>(); - for (int i = 0; i < clusterCount; i++) { - final Cluster cluster = clusterSet.getClusters().get(i); - //tasks.add(new Runnable() { - // public void run() { - try { - info.getClustersInfos().put(cluster.getId(), - computeClusterInfos(cluster, clusterSet.getDistanceFunction())); - } catch (Throwable t) { - log.warn("Error computing cluster set info", t); - } - //} - //}); - } - - - //MultiThreadUtils.parallelTasks(tasks, executorService); - - //tasks = new ArrayList<>(); - for (int i = 0; i < clusterCount; i++) { - final int clusterIdx = i; - final Cluster fromCluster = clusterSet.getClusters().get(i); - //tasks.add(new Runnable() { - //public void 
run() { - try { - for (int k = clusterIdx + 1, l = clusterSet.getClusterCount(); k < l; k++) { - Cluster toCluster = clusterSet.getClusters().get(k); - double distance = Nd4j.getExecutioner() - .execAndReturn(ClusterUtils.createDistanceFunctionOp( - clusterSet.getDistanceFunction(), - fromCluster.getCenter().getArray(), - toCluster.getCenter().getArray())) - .getFinalResult().doubleValue(); - info.getDistancesBetweenClustersCenters().put(fromCluster.getId(), toCluster.getId(), - distance); - } - } catch (Throwable t) { - log.warn("Error computing distances", t); - } - // } - //}); - - } - - //MultiThreadUtils.parallelTasks(tasks, executorService); - - return info; - } - - /** - * - * @param cluster - * @param distanceFunction - * @return - */ - public static ClusterInfo computeClusterInfos(Cluster cluster, Distance distanceFunction) { - ClusterInfo info = new ClusterInfo(cluster.isInverse(), true); - for (int i = 0, j = cluster.getPoints().size(); i < j; i++) { - Point point = cluster.getPoints().get(i); - //shouldn't need to inverse here. other parts of - //the code should interpret the "distance" or score here - double distance = Nd4j.getExecutioner() - .execAndReturn(ClusterUtils.createDistanceFunctionOp(distanceFunction, - cluster.getCenter().getArray(), point.getArray())) - .getFinalResult().doubleValue(); - info.getPointDistancesFromCenter().put(point.getId(), distance); - double diff = info.getTotalPointDistanceFromCenter() + distance; - info.setTotalPointDistanceFromCenter(diff); - } - - if (!cluster.getPoints().isEmpty()) - info.setAveragePointDistanceFromCenter(info.getTotalPointDistanceFromCenter() / cluster.getPoints().size()); - return info; - } - - /** - * - * @param optimization - * @param clusterSet - * @param clusterSetInfo - * @param executor - * @return - */ - public static boolean applyOptimization(OptimisationStrategy optimization, ClusterSet clusterSet, - ClusterSetInfo clusterSetInfo, ExecutorService executor) { - - if (optimization.isClusteringOptimizationType( - ClusteringOptimizationType.MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE)) { - int splitCount = ClusterUtils.splitClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet, - clusterSetInfo, optimization.getClusteringOptimizationValue(), executor); - return splitCount > 0; - } - - if (optimization.isClusteringOptimizationType( - ClusteringOptimizationType.MINIMIZE_MAXIMUM_POINT_TO_CENTER_DISTANCE)) { - int splitCount = ClusterUtils.splitClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet, - clusterSetInfo, optimization.getClusteringOptimizationValue(), executor); - return splitCount > 0; - } - - return false; - } - - /** - * - * @param clusterSet - * @param info - * @param count - * @return - */ - public static List getMostSpreadOutClusters(final ClusterSet clusterSet, final ClusterSetInfo info, - int count) { - List clusters = new ArrayList<>(clusterSet.getClusters()); - Collections.sort(clusters, new Comparator() { - public int compare(Cluster o1, Cluster o2) { - Double o1TotalDistance = info.getClusterInfo(o1.getId()).getTotalPointDistanceFromCenter(); - Double o2TotalDistance = info.getClusterInfo(o2.getId()).getTotalPointDistanceFromCenter(); - int comp = o1TotalDistance.compareTo(o2TotalDistance); - return !clusterSet.getClusters().get(0).isInverse() ? 
-comp : comp; - } - }); - - return clusters.subList(0, count); - } - - /** - * - * @param clusterSet - * @param info - * @param maximumAverageDistance - * @return - */ - public static List getClustersWhereAverageDistanceFromCenterGreaterThan(final ClusterSet clusterSet, - final ClusterSetInfo info, double maximumAverageDistance) { - List clusters = new ArrayList<>(); - for (Cluster cluster : clusterSet.getClusters()) { - ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId()); - if (clusterInfo != null) { - //distances - if (clusterInfo.isInverse()) { - if (clusterInfo.getAveragePointDistanceFromCenter() < maximumAverageDistance) - clusters.add(cluster); - } else { - if (clusterInfo.getAveragePointDistanceFromCenter() > maximumAverageDistance) - clusters.add(cluster); - } - - } - - } - return clusters; - } - - /** - * - * @param clusterSet - * @param info - * @param maximumDistance - * @return - */ - public static List getClustersWhereMaximumDistanceFromCenterGreaterThan(final ClusterSet clusterSet, - final ClusterSetInfo info, double maximumDistance) { - List clusters = new ArrayList<>(); - for (Cluster cluster : clusterSet.getClusters()) { - ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId()); - if (clusterInfo != null) { - if (clusterInfo.isInverse() && clusterInfo.getMaxPointDistanceFromCenter() < maximumDistance) { - clusters.add(cluster); - } else if (clusterInfo.getMaxPointDistanceFromCenter() > maximumDistance) { - clusters.add(cluster); - - } - } - } - return clusters; - } - - /** - * - * @param clusterSet - * @param clusterSetInfo - * @param count - * @param executorService - * @return - */ - public static int splitMostSpreadOutClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count, - ExecutorService executorService) { - List clustersToSplit = getMostSpreadOutClusters(clusterSet, clusterSetInfo, count); - splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService); - return clustersToSplit.size(); - } - - /** - * - * @param clusterSet - * @param clusterSetInfo - * @param maxWithinClusterDistance - * @param executorService - * @return - */ - public static int splitClustersWhereAverageDistanceFromCenterGreaterThan(ClusterSet clusterSet, - ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) { - List clustersToSplit = getClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo, - maxWithinClusterDistance); - splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService); - return clustersToSplit.size(); - } - - /** - * - * @param clusterSet - * @param clusterSetInfo - * @param maxWithinClusterDistance - * @param executorService - * @return - */ - public static int splitClustersWhereMaximumDistanceFromCenterGreaterThan(ClusterSet clusterSet, - ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) { - List clustersToSplit = getClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo, - maxWithinClusterDistance); - splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService); - return clustersToSplit.size(); - } - - /** - * - * @param clusterSet - * @param clusterSetInfo - * @param count - * @param executorService - */ - public static void splitMostPopulatedClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count, - ExecutorService executorService) { - List clustersToSplit = 
clusterSet.getMostPopulatedClusters(count); - splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService); - } - - /** - * - * @param clusterSet - * @param clusterSetInfo - * @param clusters - * @param maxDistance - * @param executorService - */ - public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo, - List clusters, final double maxDistance, ExecutorService executorService) { - final Random random = new Random(); - List tasks = new ArrayList<>(); - for (final Cluster cluster : clusters) { - tasks.add(new Runnable() { - public void run() { - try { - ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId()); - List fartherPoints = clusterInfo.getPointsFartherFromCenterThan(maxDistance); - int rank = Math.min(fartherPoints.size(), 3); - String pointId = fartherPoints.get(random.nextInt(rank)); - Point point = cluster.removePoint(pointId); - clusterSet.addNewClusterWithCenter(point); - } catch (Throwable t) { - log.warn("Error splitting clusters", t); - } - } - }); - } - MultiThreadUtils.parallelTasks(tasks, executorService); - } - - /** - * - * @param clusterSet - * @param clusterSetInfo - * @param clusters - * @param executorService - */ - public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo, - List clusters, ExecutorService executorService) { - final Random random = new Random(); - List tasks = new ArrayList<>(); - for (final Cluster cluster : clusters) { - tasks.add(new Runnable() { - public void run() { - try { - Point point = cluster.getPoints().remove(random.nextInt(cluster.getPoints().size())); - clusterSet.addNewClusterWithCenter(point); - } catch (Throwable t) { - log.warn("Error Splitting clusters (2)", t); - } - } - }); - } - - MultiThreadUtils.parallelTasks(tasks, executorService); - } - - public static ReduceOp createDistanceFunctionOp(Distance distanceFunction, INDArray x, INDArray y, int...dimensions){ - val op = createDistanceFunctionOp(distanceFunction, x, y); - op.setDimensions(dimensions); - return op; - } - - public static ReduceOp createDistanceFunctionOp(Distance distanceFunction, INDArray x, INDArray y){ - switch (distanceFunction){ - case COSINE_DISTANCE: - return new CosineDistance(x,y); - case COSINE_SIMILARITY: - return new CosineSimilarity(x,y); - case DOT: - return new Dot(x,y); - case EUCLIDEAN: - return new EuclideanDistance(x,y); - case JACCARD: - return new JaccardDistance(x,y); - case MANHATTAN: - return new ManhattanDistance(x,y); - default: - throw new IllegalStateException("Unknown distance function: " + distanceFunction); - } - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/Point.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/Point.java deleted file mode 100644 index 14147b004..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/Point.java +++ /dev/null @@ -1,107 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. 
- * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.cluster; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.NoArgsConstructor; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; -import java.util.UUID; - -/** - * - */ -@Data -@NoArgsConstructor(access = AccessLevel.PROTECTED) -public class Point implements Serializable { - - private static final long serialVersionUID = -6658028541426027226L; - - private String id = UUID.randomUUID().toString(); - private String label; - private INDArray array; - - - /** - * - * @param array - */ - public Point(INDArray array) { - super(); - this.array = array; - } - - /** - * - * @param id - * @param array - */ - public Point(String id, INDArray array) { - super(); - this.id = id; - this.array = array; - } - - public Point(String id, String label, double[] data) { - this(id, label, Nd4j.create(data)); - } - - public Point(String id, String label, INDArray array) { - super(); - this.id = id; - this.label = label; - this.array = array; - } - - - /** - * - * @param matrix - * @return - */ - public static List toPoints(INDArray matrix) { - List arr = new ArrayList<>(matrix.rows()); - for (int i = 0; i < matrix.rows(); i++) { - arr.add(new Point(matrix.getRow(i))); - } - - return arr; - } - - /** - * - * @param vectors - * @return - */ - public static List toPoints(List vectors) { - List points = new ArrayList<>(); - for (INDArray vector : vectors) - points.add(new Point(vector)); - return points; - } - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/PointClassification.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/PointClassification.java deleted file mode 100644 index 6951b4a03..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/PointClassification.java +++ /dev/null @@ -1,40 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.cluster; - -import lombok.AccessLevel; -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.io.Serializable; - -@Data -@NoArgsConstructor(access = AccessLevel.PROTECTED) -@AllArgsConstructor -public class PointClassification implements Serializable { - - private Cluster cluster; - private double distanceFromCenter; - private boolean newLocation; - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/ClusteringAlgorithmCondition.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/ClusteringAlgorithmCondition.java deleted file mode 100644 index 852a58920..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/ClusteringAlgorithmCondition.java +++ /dev/null @@ -1,37 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.condition; - -import org.deeplearning4j.clustering.iteration.IterationHistory; - -/** - * - */ -public interface ClusteringAlgorithmCondition { - - /** - * - * @param iterationHistory - * @return - */ - boolean isSatisfied(IterationHistory iterationHistory); - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/ConvergenceCondition.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/ConvergenceCondition.java deleted file mode 100644 index 6c2659f60..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/ConvergenceCondition.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. 
- * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.condition; - -import lombok.AccessLevel; -import lombok.AllArgsConstructor; -import lombok.NoArgsConstructor; -import org.deeplearning4j.clustering.iteration.IterationHistory; -import org.nd4j.linalg.indexing.conditions.Condition; -import org.nd4j.linalg.indexing.conditions.LessThan; - -import java.io.Serializable; - -@NoArgsConstructor(access = AccessLevel.PROTECTED) -@AllArgsConstructor(access = AccessLevel.PROTECTED) -public class ConvergenceCondition implements ClusteringAlgorithmCondition, Serializable { - - private Condition convergenceCondition; - private double pointsDistributionChangeRate; - - - /** - * - * @param pointsDistributionChangeRate - * @return - */ - public static ConvergenceCondition distributionVariationRateLessThan(double pointsDistributionChangeRate) { - Condition condition = new LessThan(pointsDistributionChangeRate); - return new ConvergenceCondition(condition, pointsDistributionChangeRate); - } - - - /** - * - * @param iterationHistory - * @return - */ - public boolean isSatisfied(IterationHistory iterationHistory) { - int iterationCount = iterationHistory.getIterationCount(); - if (iterationCount <= 1) - return false; - - double variation = iterationHistory.getMostRecentClusterSetInfo().getPointLocationChange().get(); - variation /= iterationHistory.getMostRecentClusterSetInfo().getPointsCount(); - - return convergenceCondition.apply(variation); - } - - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/FixedIterationCountCondition.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/FixedIterationCountCondition.java deleted file mode 100644 index 7eda7a7ec..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/FixedIterationCountCondition.java +++ /dev/null @@ -1,61 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.condition; - -import lombok.AccessLevel; -import lombok.NoArgsConstructor; -import org.deeplearning4j.clustering.iteration.IterationHistory; -import org.nd4j.linalg.indexing.conditions.Condition; -import org.nd4j.linalg.indexing.conditions.GreaterThanOrEqual; - -import java.io.Serializable; - -/** - * - */ -@NoArgsConstructor(access = AccessLevel.PROTECTED) -public class FixedIterationCountCondition implements ClusteringAlgorithmCondition, Serializable { - - private Condition iterationCountCondition; - - protected FixedIterationCountCondition(int initialClusterCount) { - iterationCountCondition = new GreaterThanOrEqual(initialClusterCount); - } - - /** - * - * @param iterationCount - * @return - */ - public static FixedIterationCountCondition iterationCountGreaterThan(int iterationCount) { - return new FixedIterationCountCondition(iterationCount); - } - - /** - * - * @param iterationHistory - * @return - */ - public boolean isSatisfied(IterationHistory iterationHistory) { - return iterationCountCondition.apply(iterationHistory == null ? 0 : iterationHistory.getIterationCount()); - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/VarianceVariationCondition.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/VarianceVariationCondition.java deleted file mode 100644 index ff91dd7eb..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/VarianceVariationCondition.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.condition; - -import lombok.AccessLevel; -import lombok.AllArgsConstructor; -import lombok.NoArgsConstructor; -import org.deeplearning4j.clustering.iteration.IterationHistory; -import org.nd4j.linalg.indexing.conditions.Condition; -import org.nd4j.linalg.indexing.conditions.LessThan; - -import java.io.Serializable; - -/** - * - */ -@NoArgsConstructor(access = AccessLevel.PROTECTED) -@AllArgsConstructor -public class VarianceVariationCondition implements ClusteringAlgorithmCondition, Serializable { - - private Condition varianceVariationCondition; - private int period; - - - - /** - * - * @param varianceVariation - * @param period - * @return - */ - public static VarianceVariationCondition varianceVariationLessThan(double varianceVariation, int period) { - Condition condition = new LessThan(varianceVariation); - return new VarianceVariationCondition(condition, period); - } - - - /** - * - * @param iterationHistory - * @return - */ - public boolean isSatisfied(IterationHistory iterationHistory) { - if (iterationHistory.getIterationCount() <= period) - return false; - - for (int i = 0, j = iterationHistory.getIterationCount(); i < period; i++) { - double variation = iterationHistory.getIterationInfo(j - i).getClusterSetInfo() - .getPointDistanceFromClusterVariance(); - variation -= iterationHistory.getIterationInfo(j - i - 1).getClusterSetInfo() - .getPointDistanceFromClusterVariance(); - variation /= iterationHistory.getIterationInfo(j - i - 1).getClusterSetInfo() - .getPointDistanceFromClusterVariance(); - - if (!varianceVariationCondition.apply(variation)) - return false; - } - - return true; - } - - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterInfo.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterInfo.java deleted file mode 100644 index 2b78ee3e8..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterInfo.java +++ /dev/null @@ -1,114 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.info; - -import lombok.Data; - -import java.io.Serializable; -import java.util.*; -import java.util.concurrent.ConcurrentHashMap; - -/** - * - */ -@Data -public class ClusterInfo implements Serializable { - - private double averagePointDistanceFromCenter; - private double maxPointDistanceFromCenter; - private double pointDistanceFromCenterVariance; - private double totalPointDistanceFromCenter; - private boolean inverse; - private Map<String, Double> pointDistancesFromCenter = new ConcurrentHashMap<>(); - - public ClusterInfo(boolean inverse) { - this(false, inverse); - } - - /** - * - * @param threadSafe - * @param inverse - */ - public ClusterInfo(boolean threadSafe, boolean inverse) { - super(); - this.inverse = inverse; - if (threadSafe) { - pointDistancesFromCenter = Collections.synchronizedMap(pointDistancesFromCenter); - } - } - - /** - * - * @return - */ - public Set<Map.Entry<String, Double>> getSortedPointDistancesFromCenter() { - SortedSet<Map.Entry<String, Double>> sortedEntries = new TreeSet<>(new Comparator<Map.Entry<String, Double>>() { - @Override - public int compare(Map.Entry<String, Double> e1, Map.Entry<String, Double> e2) { - int res = e1.getValue().compareTo(e2.getValue()); - return res != 0 ? res : 1; - } - }); - sortedEntries.addAll(pointDistancesFromCenter.entrySet()); - return sortedEntries; - } - - /** - * - * @return - */ - public Set<Map.Entry<String, Double>> getReverseSortedPointDistancesFromCenter() { - SortedSet<Map.Entry<String, Double>> sortedEntries = new TreeSet<>(new Comparator<Map.Entry<String, Double>>() { - @Override - public int compare(Map.Entry<String, Double> e1, Map.Entry<String, Double> e2) { - int res = e1.getValue().compareTo(e2.getValue()); - return -(res != 0 ? res : 1); - } - }); - sortedEntries.addAll(pointDistancesFromCenter.entrySet()); - return sortedEntries; - } - - /** - * - * @param maxDistance - * @return - */ - public List<String> getPointsFartherFromCenterThan(double maxDistance) { - Set<Map.Entry<String, Double>> sorted = getReverseSortedPointDistancesFromCenter(); - List<String> ids = new ArrayList<>(); - for (Map.Entry<String, Double> entry : sorted) { - if (inverse && entry.getValue() < -maxDistance) - break; - else if (entry.getValue() > maxDistance) - break; - - ids.add(entry.getKey()); - } - return ids; - } - - - -}
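A usage sketch for the deleted ClusterInfo above. It is hypothetical, but every constructor and method shown appears in this diff: distances are recorded per point id, then read back sorted, which is how the cluster-set statistics below consume this class.

import org.deeplearning4j.clustering.info.ClusterInfo;
import java.util.Map;

class ClusterInfoSketch {
    public static void main(String[] args) {
        ClusterInfo info = new ClusterInfo(false, false); // threadSafe = false, inverse = false
        info.getPointDistancesFromCenter().put("pointA", 0.25); // @Data generates this getter
        info.getPointDistancesFromCenter().put("pointB", 0.75);
        for (Map.Entry<String, Double> e : info.getSortedPointDistancesFromCenter())
            System.out.println(e.getKey() + " -> " + e.getValue()); // ascending by distance
    }
}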
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java deleted file mode 100644 index 3ddfd1b25..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java +++ /dev/null @@ -1,142 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.info; - -import org.nd4j.shade.guava.collect.HashBasedTable; -import org.nd4j.shade.guava.collect.Table; -import org.deeplearning4j.clustering.cluster.Cluster; -import org.deeplearning4j.clustering.cluster.ClusterSet; - -import java.io.Serializable; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicInteger; - -public class ClusterSetInfo implements Serializable { - - private Map<String, ClusterInfo> clustersInfos = new HashMap<>(); - private Table<String, String, Double> distancesBetweenClustersCenters = HashBasedTable.create(); - private AtomicInteger pointLocationChange; - private boolean threadSafe; - private boolean inverse; - - public ClusterSetInfo(boolean inverse) { - this(inverse, false); - } - - /** - * - * @param inverse - * @param threadSafe - */ - public ClusterSetInfo(boolean inverse, boolean threadSafe) { - this.pointLocationChange = new AtomicInteger(0); - this.threadSafe = threadSafe; - this.inverse = inverse; - if (threadSafe) { - clustersInfos = Collections.synchronizedMap(clustersInfos); - } - } - - - /** - * - * @param clusterSet - * @param threadSafe - * @return - */ - public static ClusterSetInfo initialize(ClusterSet clusterSet, boolean threadSafe) { - ClusterSetInfo info = new ClusterSetInfo(clusterSet.isInverse(), threadSafe); - for (int i = 0, j = clusterSet.getClusterCount(); i < j; i++) - info.addClusterInfo(clusterSet.getClusters().get(i).getId()); - return info; - } - - public void removeClusterInfos(List<Cluster> clusters) { - for (Cluster cluster : clusters) { - clustersInfos.remove(cluster.getId()); - } - } - - public ClusterInfo addClusterInfo(String clusterId) { - // pass both flags: the one-arg ClusterInfo constructor would read threadSafe as inverse - ClusterInfo clusterInfo = new ClusterInfo(this.threadSafe, this.inverse); - clustersInfos.put(clusterId, clusterInfo); - return clusterInfo; - } - - public ClusterInfo getClusterInfo(String clusterId) { - return clustersInfos.get(clusterId); - } - - public double getAveragePointDistanceFromClusterCenter() { - if (clustersInfos == null || clustersInfos.isEmpty()) - return 0; - - double average = 0; - for (ClusterInfo info : clustersInfos.values()) - average += info.getAveragePointDistanceFromCenter(); - return average / clustersInfos.size(); - } - - public double getPointDistanceFromClusterVariance() { - if (clustersInfos == null || clustersInfos.isEmpty()) - return 0; - - double average = 0; - for (ClusterInfo info : clustersInfos.values()) - average += info.getPointDistanceFromCenterVariance(); - return average / clustersInfos.size(); - } - - public int getPointsCount() { - int count = 0; - for (ClusterInfo clusterInfo : clustersInfos.values()) - count += clusterInfo.getPointDistancesFromCenter().size(); - return count; - } - - public Map<String, ClusterInfo> getClustersInfos() { - return clustersInfos; - } - - public void setClustersInfos(Map<String, ClusterInfo> clustersInfos) { - this.clustersInfos = clustersInfos; - } - - public Table<String, String, Double> getDistancesBetweenClustersCenters() { - return distancesBetweenClustersCenters; - } - - public void setDistancesBetweenClustersCenters(Table<String, String, Double> interClusterDistances) { - this.distancesBetweenClustersCenters = interClusterDistances; - } - - public AtomicInteger getPointLocationChange() { - return pointLocationChange; - } - - public void setPointLocationChange(AtomicInteger pointLocationChange) { - this.pointLocationChange = pointLocationChange; - } - -}
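A hedged sketch for the deleted ClusterSetInfo above (names are from this diff; the ClusterSet argument comes from whatever clustering run produced it): initialize one ClusterInfo per cluster, then read the aggregates that the convergence conditions earlier in this diff consume.

import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.info.ClusterSetInfo;

class ClusterSetInfoSketch {
    static void summarize(ClusterSet clusterSet) {
        ClusterSetInfo info = ClusterSetInfo.initialize(clusterSet, true); // threadSafe = true
        System.out.println("mean point-to-center distance: " + info.getAveragePointDistanceFromClusterCenter());
        System.out.println("distance variance: " + info.getPointDistanceFromClusterVariance());
        System.out.println("points tracked: " + info.getPointsCount());
    }
}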
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/iteration/IterationHistory.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/iteration/IterationHistory.java deleted file mode 100644 index 0854e5eb1..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/iteration/IterationHistory.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.iteration; - -import lombok.Getter; -import lombok.Setter; -import org.deeplearning4j.clustering.info.ClusterSetInfo; - -import java.io.Serializable; -import java.util.HashMap; -import java.util.Map; - -public class IterationHistory implements Serializable { - @Getter - @Setter - private Map<Integer, IterationInfo> iterationsInfos = new HashMap<>(); - - /** - * - * @return - */ - public ClusterSetInfo getMostRecentClusterSetInfo() { - IterationInfo iterationInfo = getMostRecentIterationInfo(); - return iterationInfo == null ? null : iterationInfo.getClusterSetInfo(); - } - - /** - * - * @return - */ - public IterationInfo getMostRecentIterationInfo() { - return getIterationInfo(getIterationCount() - 1); - } - - /** - * - * @return - */ - public int getIterationCount() { - return getIterationsInfos().size(); - } - - /** - * - * @param iterationIdx - * @return - */ - public IterationInfo getIterationInfo(int iterationIdx) { - return getIterationsInfos().get(iterationIdx); - } - - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/iteration/IterationInfo.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/iteration/IterationInfo.java deleted file mode 100644 index 0036f3c47..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/iteration/IterationInfo.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.iteration; - -import lombok.AccessLevel; -import lombok.Data; -import lombok.NoArgsConstructor; -import org.deeplearning4j.clustering.info.ClusterSetInfo; - -import java.io.Serializable; - -@Data -@NoArgsConstructor(access = AccessLevel.PROTECTED) -public class IterationInfo implements Serializable { - - private int index; - private ClusterSetInfo clusterSetInfo; - private boolean strategyApplied; - - public IterationInfo(int index) { - super(); - this.index = index; - } - - public IterationInfo(int index, ClusterSetInfo clusterSetInfo) { - super(); - this.index = index; - this.clusterSetInfo = clusterSetInfo; - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java deleted file mode 100644 index c3e0bc418..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java +++ /dev/null @@ -1,142 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.kdtree; - -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.custom.KnnMinDistance; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.Pair; - -import java.io.Serializable; - -public class HyperRect implements Serializable {
 - - //private List points; - private float[] lowerEnds; - private float[] higherEnds; - private INDArray lowerEndsIND; - private INDArray higherEndsIND; - - public HyperRect(float[] lowerEndsIn, float[] higherEndsIn) { - this.lowerEnds = new float[lowerEndsIn.length]; - this.higherEnds = new float[lowerEndsIn.length]; - System.arraycopy(lowerEndsIn, 0 , this.lowerEnds, 0, lowerEndsIn.length); - System.arraycopy(higherEndsIn, 0 , this.higherEnds, 0, higherEndsIn.length); - lowerEndsIND = Nd4j.createFromArray(lowerEnds); - higherEndsIND = Nd4j.createFromArray(higherEnds); - } - - public HyperRect(float[] point) { - this(point, point); - } - - public HyperRect(Pair<float[], float[]> ends) { - this(ends.getFirst(), ends.getSecond()); - } - - - public void enlargeTo(INDArray point) { - float[] pointAsArray = point.toFloatVector(); - for (int i = 0; i < lowerEnds.length; i++) { - float p = pointAsArray[i]; - if (lowerEnds[i] > p) - lowerEnds[i] = p; - else if (higherEnds[i] < p) - higherEnds[i] = p; - } - } - - public static Pair<float[], float[]> point(INDArray vector) { - Pair<float[], float[]> ret = new Pair<>(); - float[] curr = new float[(int) vector.length()]; - for (int i = 0; i < vector.length(); i++) { - curr[i] = vector.getFloat(i); - } - ret.setFirst(curr); - ret.setSecond(curr); - return ret; - } - - - /*public List contains(INDArray hPoint) { - List ret = new ArrayList<>(); - for (int i = 0; i < hPoint.length(); i++) { - ret.add(lowerEnds[i] <= hPoint.getDouble(i) && - higherEnds[i] >= hPoint.getDouble(i)); - } - return ret; - }*/ - - public double minDistance(INDArray hPoint, INDArray output) { - Nd4j.exec(new KnnMinDistance(hPoint, lowerEndsIND, higherEndsIND, output)); - return output.getFloat(0); - - /*double ret = 0.0; - double[] pointAsArray = hPoint.toDoubleVector(); - for (int i = 0; i < pointAsArray.length; i++) { - double p = pointAsArray[i]; - if (!(lowerEnds[i] <= p || higherEnds[i] <= p)) { - if (p < lowerEnds[i]) - ret += Math.pow((p - lowerEnds[i]), 2); - else - ret += Math.pow((p - higherEnds[i]), 2); - } - } - ret = Math.pow(ret, 0.5); - return ret;*/ - } - - public HyperRect getUpper(INDArray hPoint, int desc) { - //Interval interval = points.get(desc); - float higher = higherEnds[desc]; - float d = hPoint.getFloat(desc); - if (higher < d) - return null; - HyperRect ret = new HyperRect(lowerEnds,higherEnds); - if (ret.lowerEnds[desc] < d) - ret.lowerEnds[desc] = d; - return ret; - } - - public HyperRect getLower(INDArray hPoint, int desc) { - //Interval interval = points.get(desc); - float lower = lowerEnds[desc]; - float d = hPoint.getFloat(desc); - if (lower > d) - return null; - HyperRect ret = new HyperRect(lowerEnds,higherEnds); - //Interval i2 = ret.points.get(desc); - if (ret.higherEnds[desc] > d) - ret.higherEnds[desc] = d; - return ret; - } - - @Override - public String toString() { - String retVal = ""; - retVal += "["; - for (int i = 0; i < lowerEnds.length; ++i) { - retVal += "(" + lowerEnds[i] + " - " + higherEnds[i] + ") "; - } - retVal += "]"; - return retVal; - } -}
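A hedged sketch of the deleted HyperRect above (constructor and methods exactly as in this diff): KDTree, deleted next, uses these boxes to lower-bound the distance from a query to everything a subtree can contain, which is what lets it prune branches.

import org.deeplearning4j.clustering.kdtree.HyperRect;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

class HyperRectSketch {
    public static void main(String[] args) {
        HyperRect rect = new HyperRect(new float[]{0f, 0f}, new float[]{1f, 1f});
        INDArray scratch = Nd4j.scalar(0f);                   // output buffer for the KnnMinDistance op
        double d = rect.minDistance(Nd4j.createFromArray(3f, 0.5f), scratch);
        System.out.println(rect + " minDistance=" + d);       // gap along x from the box edge at 1
        rect.enlargeTo(Nd4j.createFromArray(2f, 0.5f));       // grows the float bounds to cover x in [0, 2]
    }
}
diff --git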
a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java deleted file mode 100644 index fd77c8342..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java +++ /dev/null @@ -1,370 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.kdtree; - -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.reduce.bool.Any; -import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.Pair; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; -import java.util.List; - -public class KDTree implements Serializable { - - private KDNode root; - private int dims = 100; - public final static int GREATER = 1; - public final static int LESS = 0; - private int size = 0; - private HyperRect rect; - - public KDTree(int dims) { - this.dims = dims; - } - - /** - * Insert a point in to the tree - * @param point the point to insert - */ - public void insert(INDArray point) { - if (!point.isVector() || point.length() != dims) - throw new IllegalArgumentException("Point must be a vector of length " + dims); - - if (root == null) { - root = new KDNode(point); - rect = new HyperRect(/*HyperRect.point(point)*/ point.toFloatVector()); - } else { - int disc = 0; - KDNode node = root; - KDNode insert = new KDNode(point); - int successor; - while (true) { - //exactly equal - INDArray pt = node.getPoint(); - INDArray countEq = Nd4j.getExecutioner().execAndReturn(new Any(pt.neq(point))).z(); - if (countEq.getInt(0) == 0) { - return; - } else { - successor = successor(node, point, disc); - KDNode child; - if (successor < 1) - child = node.getLeft(); - else - child = node.getRight(); - if (child == null) - break; - disc = (disc + 1) % dims; - node = child; - } - } - - if (successor < 1) - node.setLeft(insert); - - else - node.setRight(insert); - - rect.enlargeTo(point); - insert.setParent(node); - } - size++; - - } - - - public INDArray delete(INDArray point) { - KDNode node = root; - int _disc = 0; - while (node != null) { - if (node.point == point) - break; - int successor = successor(node, point, _disc); - if (successor < 1) - node = node.getLeft(); - else - node = node.getRight(); - _disc = (_disc + 1) % dims; - } - - if (node != null) { - if (node == root) { - root = 
delete(root, _disc); - } else - node = delete(node, _disc); - size--; - if (size == 1) { - rect = new HyperRect(HyperRect.point(point)); - } else if (size == 0) - rect = null; - - } - // null when the point was not found in the tree - return node == null ? null : node.getPoint(); - } - - // Share this data for recursive calls of "knn" - private float currentDistance; - private INDArray currentPoint; - private INDArray minDistance = Nd4j.scalar(0.f); - - - public List<Pair<Float, INDArray>> knn(INDArray point, float distance) { - List<Pair<Float, INDArray>> best = new ArrayList<>(); - currentDistance = distance; - currentPoint = point; - knn(root, rect, best, 0); - Collections.sort(best, new Comparator<Pair<Float, INDArray>>() { - @Override - public int compare(Pair<Float, INDArray> o1, Pair<Float, INDArray> o2) { - return Float.compare(o1.getKey(), o2.getKey()); - } - }); - - return best; - } - - - private void knn(KDNode node, HyperRect rect, List<Pair<Float, INDArray>> best, int _disc) { - if (node == null || rect == null || rect.minDistance(currentPoint, minDistance) > currentDistance) - return; - int _discNext = (_disc + 1) % dims; - float distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(currentPoint, node.point, minDistance)).getFinalResult() - .floatValue(); - - if (distance <= currentDistance) { - best.add(Pair.of(distance, node.getPoint())); - } - - HyperRect lower = rect.getLower(node.point, _disc); - HyperRect upper = rect.getUpper(node.point, _disc); - knn(node.getLeft(), lower, best, _discNext); - knn(node.getRight(), upper, best, _discNext); - } - - /** - * Query for nearest neighbor. Returns the distance and point - * @param point the point to query for - * @return - */ - public Pair<Double, INDArray> nn(INDArray point) { - return nn(root, point, rect, Double.POSITIVE_INFINITY, null, 0); - } - - - private Pair<Double, INDArray> nn(KDNode node, INDArray point, HyperRect rect, double dist, INDArray best, - int _disc) { - if (node == null || rect.minDistance(point, minDistance) > dist) - return Pair.of(Double.POSITIVE_INFINITY, null); - - int _discNext = (_disc + 1) % dims; - // distance from the query to this node's point - double dist2 = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(point, node.point)).getFinalResult().doubleValue(); - if (dist2 < dist) { - best = node.getPoint(); - dist = dist2; - } - - HyperRect lower = rect.getLower(node.point, _disc); - HyperRect upper = rect.getUpper(node.point, _disc); - - if (point.getDouble(_disc) < node.point.getDouble(_disc)) { - Pair<Double, INDArray> left = nn(node.getLeft(), point, lower, dist, best, _discNext); - Pair<Double, INDArray> right = nn(node.getRight(), point, upper, dist, best, _discNext); - if (left.getKey() < dist) - return left; - else if (right.getKey() < dist) - return right; - - } else { - Pair<Double, INDArray> left = nn(node.getRight(), point, upper, dist, best, _discNext); - Pair<Double, INDArray> right = nn(node.getLeft(), point, lower, dist, best, _discNext); - if (left.getKey() < dist) - return left; - else if (right.getKey() < dist) - return right; - } - - return Pair.of(dist, best); - - } - - private KDNode delete(KDNode delete, int _disc) { - if (delete.getLeft() != null && delete.getRight() != null) { - if (delete.getParent() != null) { - if (delete.getParent().getLeft() == delete) - delete.getParent().setLeft(null); - else - delete.getParent().setRight(null); - - } - return null; - } - - int disc = _disc; - _disc = (_disc + 1) % dims; - Pair<KDNode, Integer> qd = null; - if (delete.getRight() != null) { - qd = min(delete.getRight(), disc, _disc); - } else if (delete.getLeft() != null) - qd = max(delete.getLeft(), disc, _disc); - if (qd == null) {// is leaf - return null; - } - delete.point = qd.getKey().point; - KDNode qFather = qd.getKey().getParent(); - if (qFather.getLeft() == qd.getKey()) {
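- // the borrowed successor is itself deleted recursively from its old slot so it is not kept twice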
qFather.setLeft(delete(qd.getKey(), disc)); - } else if (qFather.getRight() == qd.getKey()) { - qFather.setRight(delete(qd.getKey(), disc)); - - } - - return delete; - - - } - - - private Pair<KDNode, Integer> max(KDNode node, int disc, int _disc) { - int discNext = (_disc + 1) % dims; - if (_disc == disc) { - KDNode child = node.getLeft(); - if (child != null) { - return max(child, disc, discNext); - } - } else if (node.getLeft() != null || node.getRight() != null) { - Pair<KDNode, Integer> left = null, right = null; - if (node.getLeft() != null) - left = max(node.getLeft(), disc, discNext); - if (node.getRight() != null) - right = max(node.getRight(), disc, discNext); - if (left != null && right != null) { - double pointLeft = left.getKey().getPoint().getDouble(disc); - double pointRight = right.getKey().getPoint().getDouble(disc); - if (pointLeft > pointRight) - return left; - else - return right; - } else if (left != null) - return left; - else - return right; - } - - return Pair.of(node, _disc); - } - - - - private Pair<KDNode, Integer> min(KDNode node, int disc, int _disc) { - int discNext = (_disc + 1) % dims; - if (_disc == disc) { - KDNode child = node.getLeft(); - if (child != null) { - return min(child, disc, discNext); - } - } else if (node.getLeft() != null || node.getRight() != null) { - Pair<KDNode, Integer> left = null, right = null; - if (node.getLeft() != null) - left = min(node.getLeft(), disc, discNext); - if (node.getRight() != null) - right = min(node.getRight(), disc, discNext); - if (left != null && right != null) { - double pointLeft = left.getKey().getPoint().getDouble(disc); - double pointRight = right.getKey().getPoint().getDouble(disc); - if (pointLeft < pointRight) - return left; - else - return right; - } else if (left != null) - return left; - else - return right; - } - - return Pair.of(node, _disc); - } - - /** - * The number of elements in the tree - * @return the number of elements in the tree - */ - public int size() { - return size; - } - - private int successor(KDNode node, INDArray point, int disc) { - for (int i = disc; i < dims; i++) { - double pointI = point.getDouble(i); - double nodePointI = node.getPoint().getDouble(i); - if (pointI < nodePointI) - return LESS; - else if (pointI > nodePointI) - return GREATER; - - } - - throw new IllegalStateException("Point is equal!"); - } - - - private static class KDNode { - private INDArray point; - private KDNode left, right, parent; - - public KDNode(INDArray point) { - this.point = point; - } - - public INDArray getPoint() { - return point; - } - - public KDNode getLeft() { - return left; - } - - public void setLeft(KDNode left) { - this.left = left; - } - - public KDNode getRight() { - return right; - } - - public void setRight(KDNode right) { - this.right = right; - } - - public KDNode getParent() { - return parent; - } - - public void setParent(KDNode parent) { - this.parent = parent; - } - } - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kmeans/KMeansClustering.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kmeans/KMeansClustering.java deleted file mode 100755 index 00b5bb3e9..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kmeans/KMeansClustering.java +++ /dev/null @@ -1,109 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the
accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.kmeans; - -import org.deeplearning4j.clustering.algorithm.BaseClusteringAlgorithm; -import org.deeplearning4j.clustering.algorithm.Distance; -import org.deeplearning4j.clustering.strategy.ClusteringStrategy; -import org.deeplearning4j.clustering.strategy.FixedClusterCountStrategy; - - -public class KMeansClustering extends BaseClusteringAlgorithm { - - private static final long serialVersionUID = 8476951388145944776L; - private static final double VARIATION_TOLERANCE= 1e-4; - - - /** - * - * @param clusteringStrategy - */ - protected KMeansClustering(ClusteringStrategy clusteringStrategy, boolean useKMeansPlusPlus) { - super(clusteringStrategy, useKMeansPlusPlus); - } - - /** - * Setup a kmeans instance - * @param clusterCount the number of clusters - * @param maxIterationCount the max number of iterations - * to run kmeans - * @param distanceFunction the distance function to use for grouping - * @return - */ - public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction, - boolean inverse, boolean useKMeansPlusPlus) { - ClusteringStrategy clusteringStrategy = - FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse); - clusteringStrategy.endWhenIterationCountEquals(maxIterationCount); - return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus); - } - - /** - * - * @param clusterCount - * @param minDistributionVariationRate - * @param distanceFunction - * @param allowEmptyClusters - * @return - */ - public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction, - boolean inverse, boolean allowEmptyClusters, boolean useKMeansPlusPlus) { - ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse) - .endWhenDistributionVariationRateLessThan(minDistributionVariationRate); - return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus); - } - - - /** - * Setup a kmeans instance - * @param clusterCount the number of clusters - * @param maxIterationCount the max number of iterations - * to run kmeans - * @param distanceFunction the distance function to use for grouping - * @return - */ - public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction, boolean useKMeansPlusPlus) { - return setup(clusterCount, maxIterationCount, distanceFunction, false, useKMeansPlusPlus); - } - - /** - * - * @param clusterCount - * @param minDistributionVariationRate - * @param distanceFunction - * @param allowEmptyClusters - * @return - */ - public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction, - boolean allowEmptyClusters, boolean 
useKMeansPlusPlus) { - ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false); - clusteringStrategy.endWhenDistributionVariationRateLessThan(minDistributionVariationRate); - return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus); - } - - public static KMeansClustering setup(int clusterCount, Distance distanceFunction, - boolean allowEmptyClusters, boolean useKMeansPlusPlus) { - ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false); - clusteringStrategy.endWhenDistributionVariationRateLessThan(VARIATION_TOLERANCE); - return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus); - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/lsh/LSH.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/lsh/LSH.java deleted file mode 100644 index b9fbffa7a..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/lsh/LSH.java +++ /dev/null @@ -1,88 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.lsh; - -import org.nd4j.linalg.api.ndarray.INDArray; - -public interface LSH { - - /** - * Returns an instance of the distance measure associated to the LSH family of this implementation. - * Beware, hashing families and their amplification constructs are distance-specific. - */ - String getDistanceMeasure(); - - /** - * Returns the size of a hash compared against in one hashing bucket, corresponding to an AND construction - * - * denoting hashLength by h, - * amplifies a (d1, d2, p1, p2) hash family into a - * (d1, d2, p1^h, p2^h)-sensitive one (match probability is decreasing with h) - * - * @return the length of the hash in the AND construction used by this index - */ - int getHashLength(); - - /** - * - * denoting numTables by n, - * amplifies a (d1, d2, p1, p2) hash family into a - * (d1, d2, (1-p1^n), (1-p2^n))-sensitive one (match probability is increasing with n) - * - * @return the # of hash tables in the OR construction used by this index - */ - int getNumTables(); - - /** - * @return The dimension of the index vectors and queries - */ - int getInDimension(); - - /** - * Populates the index with data vectors. - * @param data the vectors to index - */ - void makeIndex(INDArray data); - - /** - * Returns the set of all vectors that could approximately be considered neighbors of the query, - * without selection on the basis of distance or number of neighbors.
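- * <p> - * A hypothetical usage sketch (the constructor arguments follow the RandomProjectionLSH implementation deleted below; the shapes are illustrative): - * <pre>
- * LSH lsh = new RandomProjectionLSH(4, 10, 2, 0.1); // hashLength, numTables, inDimension, radius
- * lsh.makeIndex(Nd4j.rand(100, 2));                 // index one hundred 2-d vectors
- * INDArray candidates = lsh.bucket(Nd4j.rand(1, 2)); // per-row 0/1 indicator of bucket membership
- * </pre>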
- * @param query a vector to find neighbors for - * @return its approximate neighbors, unfiltered - */ - INDArray bucket(INDArray query); - - /** - * Returns the approximate neighbors within a distance bound. - * @param query a vector to find neighbors for - * @param maxRange the maximum distance between results and the query - * @return approximate neighbors within the distance bounds - */ - INDArray search(INDArray query, double maxRange); - - /** - * Returns at most the k closest approximate neighbors. - * @param query a vector to find neighbors for - * @param k the maximum number of closest neighbors to return - * @return at most k neighbors of the query, ordered by increasing distance - */ - INDArray search(INDArray query, int k); -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSH.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSH.java deleted file mode 100644 index 7b9873d73..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSH.java +++ /dev/null @@ -1,227 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.lsh; - -import lombok.Getter; -import lombok.val; -import org.nd4j.common.base.Preconditions; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo; -import org.nd4j.linalg.api.ops.impl.transforms.same.Sign; -import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution; -import org.nd4j.linalg.api.rng.Random; -import org.nd4j.linalg.exception.ND4JIllegalStateException; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.indexing.BooleanIndexing; -import org.nd4j.linalg.indexing.conditions.Conditions; -import org.nd4j.linalg.ops.transforms.Transforms; - -import java.util.Arrays; - - -public class RandomProjectionLSH implements LSH { - - @Override - public String getDistanceMeasure(){ - return "cosinedistance"; - } - - @Getter private int hashLength; - - @Getter private int numTables; - - @Getter private int inDimension; - - - @Getter private double radius; - - INDArray randomProjection; - - INDArray index; - - INDArray indexData; - - - private INDArray gaussianRandomMatrix(int[] shape, Random rng){ - INDArray res = Nd4j.create(shape); - - GaussianDistribution op1 = new GaussianDistribution(res, 0.0, 1.0 / Math.sqrt(shape[0])); - - Nd4j.getExecutioner().exec(op1, rng); - return res; - } - - public RandomProjectionLSH(int hashLength, int numTables, int inDimension, double radius){ - this(hashLength, numTables, inDimension, radius, Nd4j.getRandom()); - } - - /** - * Creates a locality-sensitive hashing index for the cosine distance, - * a (d1, d2, (180 − d1)/180, (180 − d2)/180)-sensitive hash family before amplification - * - * @param hashLength the length of the compared hash in an AND construction, - * @param numTables the entropy-equivalent of a number of hash tables in an OR construction, implemented here with the multiple - * probes of Panigrahy (op. cit.). - * @param inDimension the dimensionality of the points being indexed - * @param radius the radius of points to generate probes for.
Instead of using multiple physical hash tables in an OR construction, probes are drawn at random within this radius of the query (the entropy-based scheme of Panigrahy cited above). - * @param rng a Random object to draw samples from - */ - public RandomProjectionLSH(int hashLength, int numTables, int inDimension, double radius, Random rng){ - this.hashLength = hashLength; - this.numTables = numTables; - this.inDimension = inDimension; - this.radius = radius; - randomProjection = gaussianRandomMatrix(new int[]{inDimension, hashLength}, rng); - } - - /** - * This picks uniformly distributed random points on the surface of a unit sphere using the method of: - * - * An efficient method for generating uniformly distributed points on the surface of an n-dimensional sphere - * JS Hicks, RF Wheeling - Communications of the ACM, 1959 - * @param data a query to generate multiple probes for - * @return a matrix of `numTables` probe vectors - */ - public INDArray entropy(INDArray data){ - - INDArray data2 = - Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.create(numTables, inDimension), radius)); - - INDArray norms = Nd4j.norm2(data2.dup(), -1); - - Preconditions.checkState(norms.rank() == 1 && norms.size(0) == numTables, "Expected norm2 to have shape [%s], is %ndShape", norms.size(0), norms); - - data2.diviColumnVector(norms); - data2.addiRowVector(data); - return data2; - } - - /** - * Returns hash values for a particular query - * @param data a query vector - * @return its hashed value - */ - public INDArray hash(INDArray data) { - if (data.shape()[1] != inDimension){ - throw new ND4JIllegalStateException( - String.format("Invalid shape: Requested INDArray shape %s, this table expects dimension %d", - Arrays.toString(data.shape()), inDimension)); - } - INDArray projected = data.mmul(randomProjection); - INDArray res = Nd4j.getExecutioner().exec(new Sign(projected)); - return res; - } - - /** - * Populates the index. Beware, not incremental, any further call replaces the index instead of adding to it.
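- * <p> - * A hedged sketch (dimensions are illustrative only): - * <pre>
- * RandomProjectionLSH lsh = new RandomProjectionLSH(4, 10, 100, 0.1);
- * lsh.makeIndex(Nd4j.rand(1000, 100));                   // a second call would replace this index
- * INDArray within = lsh.search(Nd4j.rand(1, 100), 0.5);  // neighbors within cosine distance 0.5
- * INDArray top5 = lsh.search(Nd4j.rand(1, 100), 5);      // at most the 5 closest candidates
- * </pre>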
- * @param data the vectors to index - */ - @Override - public void makeIndex(INDArray data) { - index = hash(data); - indexData = data; - } - - // data elements in the same bucket as the query, without entropy - INDArray rawBucketOf(INDArray query){ - INDArray pattern = hash(query); - - INDArray res = Nd4j.zeros(DataType.BOOL, index.shape()); - Nd4j.getExecutioner().exec(new BroadcastEqualTo(index, pattern, res, -1)); - return res.castTo(Nd4j.defaultFloatingPointType()).min(-1); - } - - @Override - public INDArray bucket(INDArray query) { - INDArray queryRes = rawBucketOf(query); - - if(numTables > 1) { - INDArray entropyQueries = entropy(query); - - // loop, addi + conditionalreplace -> poor man's OR function - for (int i = 0; i < numTables; i++) { - INDArray row = entropyQueries.getRow(i, true); - queryRes.addi(rawBucketOf(row)); - } - BooleanIndexing.replaceWhere(queryRes, 1.0, Conditions.greaterThan(0.0)); - } - - return queryRes; - } - - // data elements in the same entropy bucket as the query, - INDArray bucketData(INDArray query){ - INDArray mask = bucket(query); - int nRes = mask.sum(0).getInt(0); - INDArray res = Nd4j.create(new int[] {nRes, inDimension}); - int j = 0; - for (int i = 0; i < nRes; i++){ - while (mask.getInt(j) == 0 && j < mask.length() - 1) { - j += 1; - } - if (mask.getInt(j) == 1) res.putRow(i, indexData.getRow(j)); - j += 1; - } - return res; - } - - @Override - public INDArray search(INDArray query, double maxRange) { - if (maxRange < 0) - throw new IllegalArgumentException("ANN search should have a positive maximum search radius"); - - INDArray bucketData = bucketData(query); - INDArray distances = Transforms.allCosineDistances(bucketData, query, -1); - INDArray[] idxs = Nd4j.sortWithIndices(distances, -1, true); - - INDArray shuffleIndexes = idxs[0]; - INDArray sortedDistances = idxs[1]; - int accepted = 0; - while (accepted < sortedDistances.length() && sortedDistances.getInt(accepted) <= maxRange) accepted +=1; - - INDArray res = Nd4j.create(new int[] {accepted, inDimension}); - for(int i = 0; i < accepted; i++){ - res.putRow(i, bucketData.getRow(shuffleIndexes.getInt(i))); - } - return res; - } - - @Override - public INDArray search(INDArray query, int k) { - if (k < 1) - throw new IllegalArgumentException("An ANN search for k neighbors should at least seek one neighbor"); - - INDArray bucketData = bucketData(query); - INDArray distances = Transforms.allCosineDistances(bucketData, query, -1); - INDArray[] idxs = Nd4j.sortWithIndices(distances, -1, true); - - INDArray shuffleIndexes = idxs[0]; - INDArray sortedDistances = idxs[1]; - val accepted = Math.min(k, sortedDistances.shape()[1]); - - INDArray res = Nd4j.create(accepted, inDimension); - for(int i = 0; i < accepted; i++){ - res.putRow(i, bucketData.getRow(shuffleIndexes.getInt(i))); - } - return res; - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/optimisation/ClusteringOptimization.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/optimisation/ClusteringOptimization.java deleted file mode 100644 index b65571de3..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/optimisation/ClusteringOptimization.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and 
the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.optimisation; - -import lombok.AccessLevel; -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; - -import java.io.Serializable; - -@Data -@NoArgsConstructor(access = AccessLevel.PROTECTED) -@AllArgsConstructor -public class ClusteringOptimization implements Serializable { - - private ClusteringOptimizationType type; - private double value; - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/optimisation/ClusteringOptimizationType.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/optimisation/ClusteringOptimizationType.java deleted file mode 100644 index a2220010e..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/optimisation/ClusteringOptimizationType.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.optimisation; - -/** - * - */ -public enum ClusteringOptimizationType { - MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE, MINIMIZE_MAXIMUM_POINT_TO_CENTER_DISTANCE, MINIMIZE_AVERAGE_POINT_TO_POINT_DISTANCE, MINIMIZE_MAXIMUM_POINT_TO_POINT_DISTANCE, MINIMIZE_PER_CLUSTER_POINT_COUNT -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/Cell.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/Cell.java deleted file mode 100644 index cb82b6f87..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/Cell.java +++ /dev/null @@ -1,115 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.quadtree; - -import org.nd4j.linalg.api.ndarray.INDArray; - -import java.io.Serializable; - -public class Cell implements Serializable { - private double x, y, hw, hh; - - public Cell(double x, double y, double hw, double hh) { - this.x = x; - this.y = y; - this.hw = hw; - this.hh = hh; - } - - /** - * Whether the given point is contained - * within this cell - * @param point the point to check - * @return true if the point is contained, false otherwise - */ - public boolean containsPoint(INDArray point) { - double first = point.getDouble(0), second = point.getDouble(1); - return x - hw <= first && x + hw >= first && y - hh <= second && y + hh >= second; - } - - @Override - public boolean equals(Object o) { - if (this == o) - return true; - if (!(o instanceof Cell)) - return false; - - Cell cell = (Cell) o; - - if (Double.compare(cell.hh, hh) != 0) - return false; - if (Double.compare(cell.hw, hw) != 0) - return false; - if (Double.compare(cell.x, x) != 0) - return false; - return Double.compare(cell.y, y) == 0; - - } - - @Override - public int hashCode() { - int result; - long temp; - temp = Double.doubleToLongBits(x); - result = (int) (temp ^ (temp >>> 32)); - temp = Double.doubleToLongBits(y); - result = 31 * result + (int) (temp ^ (temp >>> 32)); - temp = Double.doubleToLongBits(hw); - result = 31 * result + (int) (temp ^ (temp >>> 32)); - temp = Double.doubleToLongBits(hh); - result = 31 * result + (int) (temp ^ (temp >>> 32)); - return result; - } - - public double getX() { - return x; - } - - public void setX(double x) { - this.x = x; - } - - public double getY() { - return y; - } - - public void setY(double y) 
{ - this.y = y; - } - - public double getHw() { - return hw; - } - - public void setHw(double hw) { - this.hw = hw; - } - - public double getHh() { - return hh; - } - - public void setHh(double hh) { - this.hh = hh; - } - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java deleted file mode 100644 index 20d216b44..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java +++ /dev/null @@ -1,383 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.quadtree; - -import org.nd4j.shade.guava.util.concurrent.AtomicDouble; -import org.apache.commons.math3.util.FastMath; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.io.Serializable; - -import static java.lang.Math.max; - -public class QuadTree implements Serializable { - private QuadTree parent, northWest, northEast, southWest, southEast; - private boolean isLeaf = true; - private int size, cumSize; - private Cell boundary; - static final int QT_NO_DIMS = 2; - static final int QT_NODE_CAPACITY = 1; - private INDArray buf = Nd4j.create(QT_NO_DIMS); - private INDArray data, centerOfMass = Nd4j.create(QT_NO_DIMS); - private int[] index = new int[QT_NODE_CAPACITY]; - - - /** - * Pass in a matrix - * @param data - */ - public QuadTree(INDArray data) { - INDArray meanY = data.mean(0); - INDArray minY = data.min(0); - INDArray maxY = data.max(0); - init(data, meanY.getDouble(0), meanY.getDouble(1), - max(maxY.getDouble(0) - meanY.getDouble(0), meanY.getDouble(0) - minY.getDouble(0)) - + Nd4j.EPS_THRESHOLD, - max(maxY.getDouble(1) - meanY.getDouble(1), meanY.getDouble(1) - minY.getDouble(1)) - + Nd4j.EPS_THRESHOLD); - fill(); - } - - public QuadTree(QuadTree parent, INDArray data, Cell boundary) { - this.parent = parent; - this.boundary = boundary; - this.data = data; - - } - - public QuadTree(Cell boundary) { - this.boundary = boundary; - } - - private void init(INDArray data, double x, double y, double hw, double hh) { - boundary = new Cell(x, y, hw, hh); - this.data = data; - } - - private void fill() { - for (int i = 0; i < data.rows(); i++) - insert(i); - } - - - - /** - * Returns the cell of this element - * - * @param coordinates - * @return - */ - protected QuadTree findIndex(INDArray coordinates) { - - // Compute the sector for the coordinates - boolean left = 
(coordinates.getDouble(0) <= (boundary.getX() + boundary.getHw() / 2)); - boolean top = (coordinates.getDouble(1) <= (boundary.getY() + boundary.getHh() / 2)); - - // top left - QuadTree index = getNorthWest(); - if (left) { - // left side - if (!top) { - // bottom left - index = getSouthWest(); - } - } else { - // right side - if (top) { - // top right - index = getNorthEast(); - } else { - // bottom right - index = getSouthEast(); - - } - } - - return index; - } - - - /** - * Insert an index of the data in to the tree - * @param newIndex the index to insert in to the tree - * @return whether the index was inserted or not - */ - public boolean insert(int newIndex) { - // Ignore objects which do not belong in this quad tree - INDArray point = data.slice(newIndex); - if (!boundary.containsPoint(point)) - return false; - - cumSize++; - double mult1 = (double) (cumSize - 1) / (double) cumSize; - double mult2 = 1.0 / (double) cumSize; - - centerOfMass.muli(mult1); - centerOfMass.addi(point.mul(mult2)); - - // If there is space in this quad tree and it is a leaf, add the object here - if (isLeaf() && size < QT_NODE_CAPACITY) { - index[size] = newIndex; - size++; - return true; - } - - //duplicate point - if (size > 0) { - for (int i = 0; i < size; i++) { - INDArray compPoint = data.slice(index[i]); - if (point.getDouble(0) == compPoint.getDouble(0) && point.getDouble(1) == compPoint.getDouble(1)) - return true; - } - } - - - - // If this Node has already been subdivided just add the elements to the - // appropriate cell - if (!isLeaf()) { - QuadTree index = findIndex(point); - index.insert(newIndex); - return true; - } - - if (isLeaf()) - subDivide(); - - boolean ret = insertIntoOneOf(newIndex); - - - - return ret; - } - - private boolean insertIntoOneOf(int index) { - boolean success = false; - success = northWest.insert(index); - if (!success) - success = northEast.insert(index); - if (!success) - success = southWest.insert(index); - if (!success) - success = southEast.insert(index); - return success; - } - - - /** - * Returns whether the tree is consistent or not - * @return whether the tree is consistent or not - */ - public boolean isCorrect() { - - for (int n = 0; n < size; n++) { - INDArray point = data.slice(index[n]); - if (!boundary.containsPoint(point)) - return false; - } - - return isLeaf() || northWest.isCorrect() && northEast.isCorrect() && southWest.isCorrect() - && southEast.isCorrect(); - - } - - - - /** - * Create four children - * which fully divide this cell - * into four quads of equal area - */ - public void subDivide() { - northWest = new QuadTree(this, data, new Cell(boundary.getX() - .5 * boundary.getHw(), - boundary.getY() - .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh())); - northEast = new QuadTree(this, data, new Cell(boundary.getX() + .5 * boundary.getHw(), - boundary.getY() - .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh())); - southWest = new QuadTree(this, data, new Cell(boundary.getX() - .5 * boundary.getHw(), - boundary.getY() + .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh())); - southEast = new QuadTree(this, data, new Cell(boundary.getX() + .5 * boundary.getHw(), - boundary.getY() + .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh())); - - } - - - /** - * Compute non edge forces using barnes hut - * @param pointIndex - * @param theta - * @param negativeForce - * @param sumQ - */ - public void computeNonEdgeForces(int pointIndex, double theta, INDArray negativeForce, AtomicDouble 
sumQ) { - // Make sure that we spend no time on empty nodes or self-interactions - if (cumSize == 0 || (isLeaf() && size == 1 && index[0] == pointIndex)) - return; - - - // Compute distance between point and center-of-mass - buf.assign(data.slice(pointIndex)).subi(centerOfMass); - - double D = Nd4j.getBlasWrapper().dot(buf, buf); - - // Check whether we can use this node as a "summary" - if (isLeaf || FastMath.max(boundary.getHh(), boundary.getHw()) / FastMath.sqrt(D) < theta) { - - // Compute and add t-SNE force between point and current node - double Q = 1.0 / (1.0 + D); - double mult = cumSize * Q; - sumQ.addAndGet(mult); - mult *= Q; - negativeForce.addi(buf.mul(mult)); - - } else { - - // Recursively apply Barnes-Hut to children - northWest.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ); - northEast.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ); - southWest.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ); - southEast.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ); - } - } - - - - /** - * - * @param rowP a vector - * @param colP - * @param valP - * @param N - * @param posF - */ - public void computeEdgeForces(INDArray rowP, INDArray colP, INDArray valP, int N, INDArray posF) { - if (!rowP.isVector()) - throw new IllegalArgumentException("RowP must be a vector"); - - // Loop over all edges in the graph - double D; - for (int n = 0; n < N; n++) { - for (int i = rowP.getInt(n); i < rowP.getInt(n + 1); i++) { - - // Compute pairwise distance and Q-value - buf.assign(data.slice(n)).subi(data.slice(colP.getInt(i))); - - D = Nd4j.getBlasWrapper().dot(buf, buf); - D = valP.getDouble(i) / D; - - // Sum positive force - posF.slice(n).addi(buf.mul(D)); - - } - } - } - - - /** - * The depth of the node - * @return the depth of the node - */ - public int depth() { - if (isLeaf()) - return 1; - return 1 + max(max(northWest.depth(), northEast.depth()), max(southWest.depth(), southEast.depth())); - } - - public INDArray getCenterOfMass() { - return centerOfMass; - } - - public void setCenterOfMass(INDArray centerOfMass) { - this.centerOfMass = centerOfMass; - } - - public QuadTree getParent() { - return parent; - } - - public void setParent(QuadTree parent) { - this.parent = parent; - } - - public QuadTree getNorthWest() { - return northWest; - } - - public void setNorthWest(QuadTree northWest) { - this.northWest = northWest; - } - - public QuadTree getNorthEast() { - return northEast; - } - - public void setNorthEast(QuadTree northEast) { - this.northEast = northEast; - } - - public QuadTree getSouthWest() { - return southWest; - } - - public void setSouthWest(QuadTree southWest) { - this.southWest = southWest; - } - - public QuadTree getSouthEast() { - return southEast; - } - - public void setSouthEast(QuadTree southEast) { - this.southEast = southEast; - } - - public boolean isLeaf() { - return isLeaf; - } - - public void setLeaf(boolean isLeaf) { - this.isLeaf = isLeaf; - } - - public int getSize() { - return size; - } - - public void setSize(int size) { - this.size = size; - } - - public int getCumSize() { - return cumSize; - } - - public void setCumSize(int cumSize) { - this.cumSize = cumSize; - } - - public Cell getBoundary() { - return boundary; - } - - public void setBoundary(Cell boundary) { - this.boundary = boundary; - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPForest.java 
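// Illustrative sketch (not from the deleted file) of the Barnes-Hut acceptance
// test applied by QuadTree.computeNonEdgeForces above: a node may stand in for
// its whole subtree when its extent, relative to the distance from the query
// point to its centre of mass, falls below theta; the summarized contribution
// then uses the t-SNE kernel Q = 1 / (1 + D) with D the squared distance.
static boolean canSummarize(double halfWidth, double halfHeight, double distSquared, double theta) {
    return Math.max(halfWidth, halfHeight) / Math.sqrt(distSquared) < theta;
}
// With theta = 0.5, a node of extent 1.0 summarizes once distSquared exceeds 4.0.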
b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPForest.java deleted file mode 100644 index f814025d5..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPForest.java +++ /dev/null @@ -1,104 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.randomprojection; - -import lombok.Data; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.common.primitives.Pair; - -import java.util.ArrayList; -import java.util.List; - -/** - * - */ -@Data -public class RPForest { - - private int numTrees; - private List<RPTree> trees; - private INDArray data; - private int maxSize = 1000; - private String similarityFunction; - - /** - * Create the rp forest with the specified number of trees - * @param numTrees the number of trees in the forest - * @param maxSize the max size of each tree - * @param similarityFunction the distance function to use - */ - public RPForest(int numTrees,int maxSize,String similarityFunction) { - this.numTrees = numTrees; - this.maxSize = maxSize; - this.similarityFunction = similarityFunction; - trees = new ArrayList<>(numTrees); - - } - - - /** - * Build the trees from the given dataset - * @param x the input dataset (should be a 2d matrix) - */ - public void fit(INDArray x) { - this.data = x; - for(int i = 0; i < numTrees; i++) { - RPTree tree = new RPTree(data.columns(),maxSize,similarityFunction); - tree.buildTree(x); - trees.add(tree); - } - } - - /** - * Get all candidates relative to a specific datapoint.
- * @param input - * @return - */ - public INDArray getAllCandidates(INDArray input) { - return RPUtils.getAllCandidates(input,trees,similarityFunction); - } - - /** - * Query results up to length n - * nearest neighbors - * @param toQuery the query item - * @param n the number of nearest neighbors for the given data point - * @return the indices for the nearest neighbors - */ - public INDArray queryAll(INDArray toQuery,int n) { - return RPUtils.queryAll(toQuery,data,trees,n,similarityFunction); - } - - - /** - * Query all with the distances - * sorted by index - * @param query the query vector - * @param numResults the number of results to return - * @return a list of samples - */ - public List<Pair<Double, Integer>> queryWithDistances(INDArray query, int numResults) { - return RPUtils.queryAllWithDistances(query,this.data, trees,numResults,similarityFunction); - } - - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPHyperPlanes.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPHyperPlanes.java deleted file mode 100644 index 979013797..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPHyperPlanes.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.randomprojection; - -import lombok.Data; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -@Data -public class RPHyperPlanes { - private int dim; - private INDArray wholeHyperPlane; - - public RPHyperPlanes(int dim) { - this.dim = dim; - } - - public INDArray getHyperPlaneAt(int depth) { - if(wholeHyperPlane.isVector()) - return wholeHyperPlane; - return wholeHyperPlane.slice(depth); - } - - - /** - * Add a new random element to the hyper plane.
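// Illustrative usage of the RPForest deleted above (a sketch; the data shape
// and parameter choices are assumptions, not from the original sources):
INDArray vectors = Nd4j.rand(1000, 32);                  // 1000 row vectors of dimension 32
RPForest forest = new RPForest(10, 100, "euclidean");    // 10 trees, leaves of at most 100 points
forest.fit(vectors);
INDArray nearest = forest.queryAll(vectors.slice(0), 5); // indices of the 5 nearest rows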
- */ - public void addRandomHyperPlane() { - INDArray newPlane = Nd4j.randn(new int[] {1,dim}); - newPlane.divi(newPlane.normmaxNumber()); - if(wholeHyperPlane == null) - wholeHyperPlane = newPlane; - else { - wholeHyperPlane = Nd4j.concat(0,wholeHyperPlane,newPlane); - } - } - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPNode.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPNode.java deleted file mode 100644 index 9a103469e..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPNode.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.randomprojection; - - -import lombok.Data; - -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.Future; - -@Data -public class RPNode { - private int depth; - private RPNode left,right; - private Future leftFuture,rightFuture; - private List indices; - private double median; - private RPTree tree; - - - public RPNode(RPTree tree,int depth) { - this.depth = depth; - this.tree = tree; - indices = new ArrayList<>(); - } - - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPTree.java deleted file mode 100644 index 7fbca2b90..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPTree.java +++ /dev/null @@ -1,130 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
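// Sketch of a single routing step during a descent over the structures above
// ("goLeft" is a hypothetical helper, not part of the deleted files): score the
// point against the hyperplane stored for the node's depth using the configured
// similarity function, then branch on the node's median, mirroring RPUtils.query.
static boolean goLeft(INDArray point, RPHyperPlanes planes, RPNode node, String similarityFunction) {
    double score = RPUtils.computeDistance(similarityFunction, point, planes.getHyperPlaneAt(node.getDepth()));
    return score <= node.getMedian();
}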
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.randomprojection; - -import lombok.Builder; -import lombok.Data; -import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; -import org.nd4j.linalg.api.memory.enums.*; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.common.primitives.Pair; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.concurrent.ExecutorService; - -@Data -public class RPTree { - private RPNode root; - private RPHyperPlanes rpHyperPlanes; - private int dim; - //also known as leaf size - private int maxSize; - private INDArray X; - private String similarityFunction = "euclidean"; - private WorkspaceConfiguration workspaceConfiguration; - private ExecutorService searchExecutor; - private int searchWorkers; - - /** - * - * @param dim the dimension of the vectors - * @param maxSize the max size of the leaves - * - */ - @Builder - public RPTree(int dim, int maxSize,String similarityFunction) { - this.dim = dim; - this.maxSize = maxSize; - rpHyperPlanes = new RPHyperPlanes(dim); - root = new RPNode(this,0); - this.similarityFunction = similarityFunction; - workspaceConfiguration = WorkspaceConfiguration.builder().cyclesBeforeInitialization(1) - .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.FIRST_LOOP) - .policyMirroring(MirroringPolicy.FULL).policyReset(ResetPolicy.BLOCK_LEFT) - .policySpill(SpillPolicy.REALLOCATE).build(); - - } - - /** - * - * @param dim the dimension of the vectors - * @param maxSize the max size of the leaves - * - */ - public RPTree(int dim, int maxSize) { - this(dim,maxSize,"euclidean"); - } - - /** - * Build the tree with the given input data - * @param x the input dataset (rows are the vectors to index) - */ - - public void buildTree(INDArray x) { - this.X = x; - for(int i = 0; i < x.rows(); i++) { - root.getIndices().add(i); - } - - - - RPUtils.buildTree(this,root,rpHyperPlanes, - x,maxSize,0,similarityFunction); - } - - - - public void addNodeAtIndex(int idx,INDArray toAdd) { - RPNode query = RPUtils.query(root,rpHyperPlanes,toAdd,similarityFunction); - query.getIndices().add(idx); - } - - - public List<RPNode> getLeaves() { - List<RPNode> nodes = new ArrayList<>(); - RPUtils.scanForLeaves(nodes,getRoot()); - return nodes; - } - - - /** - * Query all with the distances - * sorted by index - * @param query the query vector - * @param numResults the number of results to return - * @return a list of samples - */ - public List<Pair<Double, Integer>> queryWithDistances(INDArray query, int numResults) { - return RPUtils.queryAllWithDistances(query,X,Arrays.asList(this),numResults,similarityFunction); - } - - public INDArray query(INDArray query,int numResults) { - return RPUtils.queryAll(query,X,Arrays.asList(this),numResults,similarityFunction); - } - - public List<Integer> getCandidates(INDArray target) { - return RPUtils.getCandidates(target,Arrays.asList(this),similarityFunction); - } - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java deleted file mode 100644 index 0bd2574e7..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java +++ /dev/null @@ -1,481
+0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.randomprojection; - -import org.nd4j.shade.guava.primitives.Doubles; -import lombok.val; -import org.nd4j.autodiff.functions.DifferentialFunction; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.ReduceOp; -import org.nd4j.linalg.api.ops.impl.reduce3.*; -import org.nd4j.linalg.exception.ND4JIllegalArgumentException; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.Pair; - -import java.util.*; - -public class RPUtils { - - - private static ThreadLocal> functionInstances = new ThreadLocal<>(); - - public static DifferentialFunction getOp(String name, - INDArray x, - INDArray y, - INDArray result) { - Map ops = functionInstances.get(); - if(ops == null) { - ops = new HashMap<>(); - functionInstances.set(ops); - } - - boolean allDistances = x.length() != y.length(); - - switch(name) { - case "cosinedistance": - if(!ops.containsKey(name) || ((CosineDistance)ops.get(name)).isComplexAccumulation() != allDistances) { - CosineDistance cosineDistance = new CosineDistance(x,y,result,allDistances); - ops.put(name,cosineDistance); - return cosineDistance; - } - else { - CosineDistance cosineDistance = (CosineDistance) ops.get(name); - return cosineDistance; - } - case "cosinesimilarity": - if(!ops.containsKey(name) || ((CosineSimilarity)ops.get(name)).isComplexAccumulation() != allDistances) { - CosineSimilarity cosineSimilarity = new CosineSimilarity(x,y,result,allDistances); - ops.put(name,cosineSimilarity); - return cosineSimilarity; - } - else { - CosineSimilarity cosineSimilarity = (CosineSimilarity) ops.get(name); - cosineSimilarity.setX(x); - cosineSimilarity.setY(y); - cosineSimilarity.setZ(result); - return cosineSimilarity; - - } - case "manhattan": - if(!ops.containsKey(name) || ((ManhattanDistance)ops.get(name)).isComplexAccumulation() != allDistances) { - ManhattanDistance manhattanDistance = new ManhattanDistance(x,y,result,allDistances); - ops.put(name,manhattanDistance); - return manhattanDistance; - } - else { - ManhattanDistance manhattanDistance = (ManhattanDistance) ops.get(name); - manhattanDistance.setX(x); - manhattanDistance.setY(y); - manhattanDistance.setZ(result); - return manhattanDistance; - } - case "jaccard": - if(!ops.containsKey(name) || ((JaccardDistance)ops.get(name)).isComplexAccumulation() != allDistances) { - JaccardDistance jaccardDistance = new JaccardDistance(x,y,result,allDistances); - ops.put(name,jaccardDistance); - return jaccardDistance; - } - else { - JaccardDistance jaccardDistance = (JaccardDistance) ops.get(name); - jaccardDistance.setX(x); - jaccardDistance.setY(y); - 
jaccardDistance.setZ(result); - return jaccardDistance; - } - case "hamming": - if(!ops.containsKey(name) || ((HammingDistance)ops.get(name)).isComplexAccumulation() != allDistances) { - HammingDistance hammingDistance = new HammingDistance(x,y,result,allDistances); - ops.put(name,hammingDistance); - return hammingDistance; - } - else { - HammingDistance hammingDistance = (HammingDistance) ops.get(name); - hammingDistance.setX(x); - hammingDistance.setY(y); - hammingDistance.setZ(result); - return hammingDistance; - } - //euclidean - default: - if(!ops.containsKey(name) || ((EuclideanDistance)ops.get(name)).isComplexAccumulation() != allDistances) { - EuclideanDistance euclideanDistance = new EuclideanDistance(x,y,result,allDistances); - ops.put(name,euclideanDistance); - return euclideanDistance; - } - else { - EuclideanDistance euclideanDistance = (EuclideanDistance) ops.get(name); - euclideanDistance.setX(x); - euclideanDistance.setY(y); - euclideanDistance.setZ(result); - return euclideanDistance; - } - } - } - - - /** - * Query all trees using the given input and data - * @param toQuery the query vector - * @param X the input data to query - * @param trees the trees to query - * @param n the number of results to search for - * @param similarityFunction the similarity function to use - * @return the indices (in order) in the ndarray - */ - public static List> queryAllWithDistances(INDArray toQuery,INDArray X,List trees,int n,String similarityFunction) { - if(trees.isEmpty()) { - throw new ND4JIllegalArgumentException("Trees is empty!"); - } - - List candidates = getCandidates(toQuery, trees,similarityFunction); - val sortedCandidates = sortCandidates(toQuery,X,candidates,similarityFunction); - int numReturns = Math.min(n,sortedCandidates.size()); - List> ret = new ArrayList<>(numReturns); - for(int i = 0; i < numReturns; i++) { - ret.add(sortedCandidates.get(i)); - } - - return ret; - } - - /** - * Query all trees using the given input and data - * @param toQuery the query vector - * @param X the input data to query - * @param trees the trees to query - * @param n the number of results to search for - * @param similarityFunction the similarity function to use - * @return the indices (in order) in the ndarray - */ - public static INDArray queryAll(INDArray toQuery,INDArray X,List trees,int n,String similarityFunction) { - if(trees.isEmpty()) { - throw new ND4JIllegalArgumentException("Trees is empty!"); - } - - List candidates = getCandidates(toQuery, trees,similarityFunction); - val sortedCandidates = sortCandidates(toQuery,X,candidates,similarityFunction); - int numReturns = Math.min(n,sortedCandidates.size()); - - INDArray result = Nd4j.create(numReturns); - for(int i = 0; i < numReturns; i++) { - result.putScalar(i,sortedCandidates.get(i).getSecond()); - } - - - return result; - } - - /** - * Get the sorted distances given the - * query vector, input data, given the list of possible search candidates - * @param x the query vector - * @param X the input data to use - * @param candidates the possible search candidates - * @param similarityFunction the similarity function to use - * @return the sorted distances - */ - public static List> sortCandidates(INDArray x,INDArray X, - List candidates, - String similarityFunction) { - int prevIdx = -1; - List> ret = new ArrayList<>(); - for(int i = 0; i < candidates.size(); i++) { - if(candidates.get(i) != prevIdx) { - ret.add(Pair.of(computeDistance(similarityFunction,X.slice(candidates.get(i)),x),candidates.get(i))); - } - - prevIdx = i; - } 
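// The caching scheme of getOp above, reduced to its core (a JDK-only sketch;
// the real method also rebuilds the op when the cached instance's accumulation
// mode no longer matches): one mutable op instance per thread, keyed by
// function name, so hot query loops avoid re-allocating ops.
private static final ThreadLocal<Map<String, ReduceOp>> OP_CACHE =
        ThreadLocal.withInitial(HashMap::new);

static ReduceOp cachedOp(String name, java.util.function.Supplier<ReduceOp> factory) {
    return OP_CACHE.get().computeIfAbsent(name, k -> factory.get());
}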
- - - Collections.sort(ret, new Comparator<Pair<Double, Integer>>() { - @Override - public int compare(Pair<Double, Integer> doubleIntegerPair, Pair<Double, Integer> t1) { - return Doubles.compare(doubleIntegerPair.getFirst(),t1.getFirst()); - } - }); - - return ret; - } - - - - /** - * Get the search candidates as indices given the input - * and similarity function - * @param x the input data to search with - * @param trees the trees to search - * @param similarityFunction the function to use for similarity - * @return the list of indices as the search results - */ - public static INDArray getAllCandidates(INDArray x,List<RPTree> trees,String similarityFunction) { - List<Integer> candidates = getCandidates(x,trees,similarityFunction); - Collections.sort(candidates); - - int prevIdx = -1; - int idxCount = 0; - List<Pair<Integer, Integer>> scores = new ArrayList<>(); - for(int i = 0; i < candidates.size(); i++) { - if(candidates.get(i) == prevIdx) { - idxCount++; - } - else if(prevIdx != -1) { - scores.add(Pair.of(idxCount,prevIdx)); - idxCount = 1; - } - - prevIdx = candidates.get(i); - } - - - scores.add(Pair.of(idxCount,prevIdx)); - - INDArray arr = Nd4j.create(scores.size()); - for(int i = 0; i < scores.size(); i++) { - arr.putScalar(i,scores.get(i).getSecond()); - } - - return arr; - } - - - /** - * Get the search candidates as indices given the input - * and similarity function - * @param x the input data to search with - * @param roots the trees to search - * @param similarityFunction the function to use for similarity - * @return the list of indices as the search results - */ - public static List<Integer> getCandidates(INDArray x,List<RPTree> roots,String similarityFunction) { - Set<Integer> ret = new LinkedHashSet<>(); - for(RPTree tree : roots) { - RPNode root = tree.getRoot(); - RPNode query = query(root,tree.getRpHyperPlanes(),x,similarityFunction); - ret.addAll(query.getIndices()); - } - - return new ArrayList<>(ret); - } - - - /** - * Query the tree starting from the given node - * using the given hyper plane and similarity function - * @param from the node to start from - * @param planes the hyper plane to query - * @param x the input data - * @param similarityFunction the similarity function to use - * @return the leaf node representing the given query from a - * search in the tree - */ - public static RPNode query(RPNode from,RPHyperPlanes planes,INDArray x,String similarityFunction) { - if(from.getLeft() == null && from.getRight() == null) { - return from; - } - - INDArray hyperPlane = planes.getHyperPlaneAt(from.getDepth()); - double dist = computeDistance(similarityFunction,x,hyperPlane); - if(dist <= from.getMedian()) { - return query(from.getLeft(),planes,x,similarityFunction); - } - - else { - return query(from.getRight(),planes,x,similarityFunction); - } - - } - - - /** - * Compute the distance between 2 vectors - * given a function name. Valid function names: - * euclidean: euclidean distance - * cosinedistance: cosine distance - * cosine similarity: cosine similarity - * manhattan: manhattan distance - * jaccard: jaccard distance - * hamming: hamming distance - * @param function the function to use (default euclidean distance) - * @param x the first vector - * @param y the second vector - * @return the distance between the 2 vectors given the inputs - */ - public static INDArray computeDistanceMulti(String function,INDArray x,INDArray y,INDArray result) { - ReduceOp op = (ReduceOp) getOp(function, x, y, result); - op.setDimensions(1); - Nd4j.getExecutioner().exec(op); - return op.z(); - } - - /** - * Compute the distance between 2 vectors - * given a function name.
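// Illustrative call to computeDistanceMulti above (a sketch; shapes are
// assumptions): with two equal-shaped matrices the op reduces along
// dimension 1, yielding one distance per row.
INDArray a = Nd4j.rand(5, 16);
INDArray b = Nd4j.rand(5, 16);
INDArray perRow = RPUtils.computeDistanceMulti("euclidean", a, b, Nd4j.create(5)); // 5 distances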
Valid function names: - * euclidean: euclidean distance - * cosinedistance: cosine distance - * cosine similarity: cosine similarity - * manhattan: manhattan distance - * jaccard: jaccard distance - * hamming: hamming distance - * @param function the function to use (default euclidean distance) - * @param x the first vector - * @param y the second vector - * @return the distance between the 2 vectors given the inputs - */ - public static double computeDistance(String function,INDArray x,INDArray y,INDArray result) { - ReduceOp op = (ReduceOp) getOp(function, x, y, result); - Nd4j.getExecutioner().exec(op); - return op.z().getDouble(0); - } - - /** - * Compute the distance between 2 vectors - * given a function name. Valid function names: - * euclidean: euclidean distance - * cosinedistance: cosine distance - * cosine similarity: cosine similarity - * manhattan: manhattan distance - * jaccard: jaccard distance - * hamming: hamming distance - * @param function the function to use (default euclidean distance) - * @param x the first vector - * @param y the second vector - * @return the distance between the 2 vectors given the inputs - */ - public static double computeDistance(String function,INDArray x,INDArray y) { - return computeDistance(function,x,y,Nd4j.scalar(0.0)); - } - - /** - * Initialize the tree given the input parameters - * @param tree the tree to initialize - * @param from the starting node - * @param planes the hyper planes to use (vector space for similarity) - * @param X the input data - * @param maxSize the max number of indices on a given leaf node - * @param depth the current depth of the tree - * @param similarityFunction the similarity function to use - */ - public static void buildTree(RPTree tree, - RPNode from, - RPHyperPlanes planes, - INDArray X, - int maxSize, - int depth, - String similarityFunction) { - if(from.getIndices().size() <= maxSize) { - //slimNode - slimNode(from); - return; - } - - - List distances = new ArrayList<>(); - RPNode left = new RPNode(tree,depth + 1); - RPNode right = new RPNode(tree,depth + 1); - - if(planes.getWholeHyperPlane() == null || depth >= planes.getWholeHyperPlane().rows()) { - planes.addRandomHyperPlane(); - } - - - INDArray hyperPlane = planes.getHyperPlaneAt(depth); - - - - for(int i = 0; i < from.getIndices().size(); i++) { - double cosineSim = computeDistance(similarityFunction,hyperPlane,X.slice(from.getIndices().get(i))); - distances.add(cosineSim); - } - - Collections.sort(distances); - from.setMedian(distances.get(distances.size() / 2)); - - - for(int i = 0; i < from.getIndices().size(); i++) { - double cosineSim = computeDistance(similarityFunction,hyperPlane,X.slice(from.getIndices().get(i))); - if(cosineSim <= from.getMedian()) { - left.getIndices().add(from.getIndices().get(i)); - } - else { - right.getIndices().add(from.getIndices().get(i)); - } - } - - //failed split - if(left.getIndices().isEmpty() || right.getIndices().isEmpty()) { - slimNode(from); - return; - } - - - from.setLeft(left); - from.setRight(right); - slimNode(from); - - - buildTree(tree,left,planes,X,maxSize,depth + 1,similarityFunction); - buildTree(tree,right,planes,X,maxSize,depth + 1,similarityFunction); - - } - - - /** - * Scan for leaves accumulating - * the nodes in the passed in list - * @param nodes the nodes so far - * @param scan the tree to scan - */ - public static void scanForLeaves(List nodes,RPTree scan) { - scanForLeaves(nodes,scan.getRoot()); - } - - /** - * Scan for leaves accumulating - * the nodes in the passed in list - * @param 
nodes the nodes so far - */ - public static void scanForLeaves(List nodes,RPNode current) { - if(current.getLeft() == null && current.getRight() == null) - nodes.add(current); - if(current.getLeft() != null) - scanForLeaves(nodes,current.getLeft()); - if(current.getRight() != null) - scanForLeaves(nodes,current.getRight()); - } - - - /** - * Prune indices from the given node - * when it's a leaf - * @param node the node to prune - */ - public static void slimNode(RPNode node) { - if(node.getRight() != null && node.getLeft() != null) { - node.getIndices().clear(); - } - - } - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/Cell.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/Cell.java deleted file mode 100644 index c89e72ab1..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/Cell.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
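// Worked micro-example (a sketch) of the median split performed by
// RPUtils.buildTree above: sort the per-point hyperplane scores, take the
// upper-middle element as the median, and send scores <= median to the left child.
List<Double> scores = new ArrayList<>(Arrays.asList(0.9, 0.1, 0.4, 0.7));
Collections.sort(scores);                      // [0.1, 0.4, 0.7, 0.9]
double median = scores.get(scores.size() / 2); // 0.7, matching distances.get(distances.size() / 2)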
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.sptree; - -import org.nd4j.linalg.api.ndarray.INDArray; - -import java.io.Serializable; - -/** - * @author Adam Gibson - */ -public class Cell implements Serializable { - private int dimension; - private INDArray corner, width; - - public Cell(int dimension) { - this.dimension = dimension; - } - - public double corner(int d) { - return corner.getDouble(d); - } - - public double width(int d) { - return width.getDouble(d); - } - - public void setCorner(int d, double corner) { - this.corner.putScalar(d, corner); - } - - public void setWidth(int d, double width) { - this.width.putScalar(d, width); - } - - public void setWidth(INDArray width) { - this.width = width; - } - - public void setCorner(INDArray corner) { - this.corner = corner; - } - - - public boolean contains(INDArray point) { - INDArray cornerMinusWidth = corner.sub(width); - INDArray cornerPlusWidth = corner.add(width); - for (int d = 0; d < dimension; d++) { - double pointD = point.getDouble(d); - if (cornerMinusWidth.getDouble(d) > pointD) - return false; - if (cornerPlusWidth.getDouble(d) < pointD) - return false; - } - return true; - - } - - public INDArray width() { - return width; - } - - public INDArray corner() { - return corner; - } - - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/DataPoint.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/DataPoint.java deleted file mode 100644 index 6681d3148..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/DataPoint.java +++ /dev/null @@ -1,95 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
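// Illustrative use of the sptree Cell deleted above (a sketch): despite the
// name, corner holds the box centre and width the half-extent per dimension,
// so containment means corner[d] - width[d] <= p[d] <= corner[d] + width[d].
Cell cell = new Cell(2);
cell.setCorner(Nd4j.create(new double[] {0.0, 0.0}));
cell.setWidth(Nd4j.create(new double[] {1.0, 1.0}));
boolean inside = cell.contains(Nd4j.create(new double[] {0.5, -0.5})); // true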
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.sptree; - -import lombok.Data; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity; -import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance; -import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance; -import org.nd4j.linalg.factory.Nd4j; - -import java.io.Serializable; - -@Data -public class DataPoint implements Serializable { - private int index; - private INDArray point; - private long d; - private String functionName; - private boolean invert = false; - - - public DataPoint(int index, INDArray point, boolean invert) { - this(index, point, "euclidean"); - this.invert = invert; - } - - public DataPoint(int index, INDArray point, String functionName, boolean invert) { - this.index = index; - this.point = point; - this.functionName = functionName; - this.d = point.length(); - this.invert = invert; - } - - - public DataPoint(int index, INDArray point) { - this(index, point, false); - } - - public DataPoint(int index, INDArray point, String functionName) { - this(index, point, functionName, false); - } - - /** - * Euclidean distance - * @param point the distance from this point to the given point - * @return the distance between the two points - */ - public float distance(DataPoint point) { - switch (functionName) { - case "euclidean": - float ret = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this.point, point.point)) - .getFinalResult().floatValue(); - return invert ? -ret : ret; - - case "cosinesimilarity": - float ret2 = Nd4j.getExecutioner().execAndReturn(new CosineSimilarity(this.point, point.point)) - .getFinalResult().floatValue(); - return invert ? -ret2 : ret2; - - case "manhattan": - float ret3 = Nd4j.getExecutioner().execAndReturn(new ManhattanDistance(this.point, point.point)) - .getFinalResult().floatValue(); - return invert ? -ret3 : ret3; - case "dot": - float dotRet = (float) Nd4j.getBlasWrapper().dot(this.point, point.point); - return invert ? -dotRet : dotRet; - default: - float ret4 = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this.point, point.point)) - .getFinalResult().floatValue(); - return invert ? -ret4 : ret4; - - } - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/HeapItem.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/HeapItem.java deleted file mode 100644 index a5ea6ea95..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/HeapItem.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.sptree; - -import java.io.Serializable; - -/** - * @author Adam Gibson - */ -public class HeapItem implements Serializable, Comparable<HeapItem> { - private int index; - private double distance; - - - public HeapItem(int index, double distance) { - this.index = index; - this.distance = distance; - } - - public int getIndex() { - return index; - } - - public void setIndex(int index) { - this.index = index; - } - - public double getDistance() { - return distance; - } - - public void setDistance(double distance) { - this.distance = distance; - } - - @Override - public boolean equals(Object o) { - if (this == o) - return true; - if (o == null || getClass() != o.getClass()) - return false; - - HeapItem heapItem = (HeapItem) o; - - if (index != heapItem.index) - return false; - return Double.compare(heapItem.distance, distance) == 0; - - } - - @Override - public int hashCode() { - int result; - long temp; - result = index; - temp = Double.doubleToLongBits(distance); - result = 31 * result + (int) (temp ^ (temp >>> 32)); - return result; - } - - @Override - public int compareTo(HeapItem o) { - // descending order: larger distances sort first - return Double.compare(o.distance, distance); - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/HeapObject.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/HeapObject.java deleted file mode 100644 index e68cf33ec..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/HeapObject.java +++ /dev/null @@ -1,72 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
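// Sketch: HeapItem above orders descending by distance, so a JDK priority
// queue polls the largest distance first, which suits evicting the worst
// entry from a bounded nearest-neighbour candidate set.
PriorityQueue<HeapItem> worstFirst = new PriorityQueue<>();
worstFirst.add(new HeapItem(0, 2.5));
worstFirst.add(new HeapItem(1, 0.5));
int farthest = worstFirst.poll().getIndex(); // 0, the entry with distance 2.5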
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.sptree; - -import lombok.Data; -import org.nd4j.linalg.api.ndarray.INDArray; - -import java.io.Serializable; - -@Data -public class HeapObject implements Serializable, Comparable<HeapObject> { - private int index; - private INDArray point; - private double distance; - - - public HeapObject(int index, INDArray point, double distance) { - this.index = index; - this.point = point; - this.distance = distance; - } - - - @Override - public boolean equals(Object o) { - if (this == o) - return true; - if (o == null || getClass() != o.getClass()) - return false; - - HeapObject heapObject = (HeapObject) o; - - if (!point.equals(heapObject.point)) - return false; - - return Double.compare(heapObject.distance, distance) == 0; - - } - - @Override - public int hashCode() { - int result; - long temp; - result = index; - temp = Double.doubleToLongBits(distance); - result = 31 * result + (int) (temp ^ (temp >>> 32)); - return result; - } - - @Override - public int compareTo(HeapObject o) { - // descending order: larger distances sort first - return Double.compare(o.distance, distance); - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java deleted file mode 100644 index 4a1bf34e4..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/sptree/SpTree.java +++ /dev/null @@ -1,425 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License.
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.sptree; - -import org.nd4j.shade.guava.util.concurrent.AtomicDouble; -import lombok.val; -import org.deeplearning4j.clustering.algorithm.Distance; -import org.deeplearning4j.nn.conf.WorkspaceMode; -import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Set; - - -/** - * @author Adam Gibson - */ -public class SpTree implements Serializable { - - - public final static String workspaceExternal = "SPTREE_LOOP_EXTERNAL"; - - - private int D; - private INDArray data; - public final static int NODE_RATIO = 8000; - private int N; - private int size; - private int cumSize; - private Cell boundary; - private INDArray centerOfMass; - private SpTree parent; - private int[] index; - private int nodeCapacity; - private int numChildren = 2; - private boolean isLeaf = true; - private Collection indices; - private SpTree[] children; - private static Logger log = LoggerFactory.getLogger(SpTree.class); - private String similarityFunction = Distance.EUCLIDEAN.toString(); - - - - public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection indices, - String similarityFunction) { - init(parent, data, corner, width, indices, similarityFunction); - } - - - public SpTree(INDArray data, Collection indices, String similarityFunction) { - this.indices = indices; - this.N = data.rows(); - this.D = data.columns(); - this.similarityFunction = similarityFunction; - data = data.dup(); - INDArray meanY = data.mean(0); - INDArray minY = data.min(0); - INDArray maxY = data.max(0); - INDArray width = Nd4j.create(data.dataType(), meanY.shape()); - for (int i = 0; i < width.length(); i++) { - width.putScalar(i, Math.max(maxY.getDouble(i) - meanY.getDouble(i), - meanY.getDouble(i) - minY.getDouble(i)) + Nd4j.EPS_THRESHOLD); - } - - try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { - init(null, data, meanY, width, indices, similarityFunction); - fill(N); - } - } - - - public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection indices) { - this(parent, data, corner, width, indices, "euclidean"); - } - - - public SpTree(INDArray data, Collection indices) { - this(data, indices, "euclidean"); - } - - - - public SpTree(INDArray data) { - this(data, new ArrayList()); - } - - public MemoryWorkspace workspace() { - return null; - } - - private void init(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection indices, - String similarityFunction) { - - this.parent = parent; - D = data.columns(); - N = data.rows(); - this.similarityFunction = similarityFunction; - nodeCapacity = N % NODE_RATIO; - index = new int[nodeCapacity]; - for (int d = 1; d < this.D; d++) - numChildren *= 2; - this.indices = indices; - isLeaf = true; - size = 0; - cumSize = 0; - children = new SpTree[numChildren]; - this.data = data; - boundary = new Cell(D); - boundary.setCorner(corner.dup()); - boundary.setWidth(width.dup()); - centerOfMass = Nd4j.create(data.dataType(), D); - } - - - - private boolean insert(int index) { - 
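// Sketch: init above doubles numChildren once per dimension beyond the first,
// starting from 2, i.e. a fan-out of 2^D per internal node — D = 2 yields a
// quadtree, D = 3 an octree.
static int fanOut(int numDimensions) {
    return 1 << numDimensions; // 2^D
}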
/*MemoryWorkspace workspace = - workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace() - : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread( - workspaceConfigurationExternal, - workspaceExternal); - try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ { - - INDArray point = data.slice(index); - /*boolean contains = false; - SpTreeCell op = new SpTreeCell(boundary.corner(), boundary.width(), point, N, contains); - Nd4j.getExecutioner().exec(op); - op.getOutputArgument(0).getScalar(0); - if (!contains) return false;*/ - if (!boundary.contains(point)) - return false; - - - cumSize++; - double mult1 = (double) (cumSize - 1) / (double) cumSize; - double mult2 = 1.0 / (double) cumSize; - centerOfMass.muli(mult1); - centerOfMass.addi(point.mul(mult2)); - // If there is space in this quad tree and it is a leaf, add the object here - if (isLeaf() && size < nodeCapacity) { - this.index[size] = index; - indices.add(point); - size++; - return true; - } - - - for (int i = 0; i < size; i++) { - INDArray compPoint = data.slice(this.index[i]); - if (compPoint.equals(point)) - return true; - } - - - if (isLeaf()) - subDivide(); - - - // Find out where the point can be inserted - for (int i = 0; i < numChildren; i++) { - if (children[i].insert(index)) - return true; - } - - throw new IllegalStateException("Shouldn't reach this state"); - } - } - - - /** - * Subdivide the node in to - * 4 children - */ - public void subDivide() { - /*MemoryWorkspace workspace = - workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace() - : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread( - workspaceConfigurationExternal, - workspaceExternal); - try (MemoryWorkspace ws = workspace.notifyScopeEntered()) */{ - - INDArray newCorner = Nd4j.create(data.dataType(), D); - INDArray newWidth = Nd4j.create(data.dataType(), D); - for (int i = 0; i < numChildren; i++) { - int div = 1; - for (int d = 0; d < D; d++) { - newWidth.putScalar(d, .5 * boundary.width(d)); - if ((i / div) % 2 == 1) - newCorner.putScalar(d, boundary.corner(d) - .5 * boundary.width(d)); - else - newCorner.putScalar(d, boundary.corner(d) + .5 * boundary.width(d)); - div *= 2; - } - - children[i] = new SpTree(this, data, newCorner, newWidth, indices); - - } - - // Move existing points to correct children - for (int i = 0; i < size; i++) { - boolean success = false; - for (int j = 0; j < this.numChildren; j++) - if (!success) - success = children[j].insert(index[i]); - - index[i] = -1; - } - - // Empty parent node - size = 0; - isLeaf = false; - } - } - - - - /** - * Compute non edge forces using barnes hut - * @param pointIndex - * @param theta - * @param negativeForce - * @param sumQ - */ - public void computeNonEdgeForces(int pointIndex, double theta, INDArray negativeForce, AtomicDouble sumQ) { - // Make sure that we spend no time on empty nodes or self-interactions - INDArray buf = Nd4j.create(data.dataType(), this.D); - - if (cumSize == 0 || (isLeaf() && size == 1 && index[0] == pointIndex)) - return; - /* MemoryWorkspace workspace = - workspaceMode == WorkspaceMode.NONE ? 
new DummyWorkspace() - : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread( - workspaceConfigurationExternal, - workspaceExternal); - try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ { - - // Compute distance between point and center-of-mass - data.slice(pointIndex).subi(centerOfMass, buf); - - double D = Nd4j.getBlasWrapper().dot(buf, buf); - // Check whether we can use this node as a "summary" - double maxWidth = boundary.width().maxNumber().doubleValue(); - // Check whether we can use this node as a "summary" - if (isLeaf() || maxWidth / Math.sqrt(D) < theta) { - - // Compute and add t-SNE force between point and current node - double Q = 1.0 / (1.0 + D); - double mult = cumSize * Q; - sumQ.addAndGet(mult); - mult *= Q; - negativeForce.addi(buf.mul(mult)); - } else { - - // Recursively apply Barnes-Hut to children - for (int i = 0; i < numChildren; i++) { - children[i].computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ); - } - - } - } - } - - - /** - * - * Compute edge forces using barnes hut - * @param rowP a vector - * @param colP - * @param valP - * @param N the number of elements - * @param posF the positive force - */ - public void computeEdgeForces(INDArray rowP, INDArray colP, INDArray valP, int N, INDArray posF) { - if (!rowP.isVector()) - throw new IllegalArgumentException("RowP must be a vector"); - - // Loop over all edges in the graph - // just execute native op - Nd4j.exec(new BarnesEdgeForces(rowP, colP, valP, data, N, posF)); - - /* - INDArray buf = Nd4j.create(data.dataType(), this.D); - double D; - for (int n = 0; n < N; n++) { - INDArray slice = data.slice(n); - for (int i = rowP.getInt(n); i < rowP.getInt(n + 1); i++) { - - // Compute pairwise distance and Q-value - slice.subi(data.slice(colP.getInt(i)), buf); - - D = 1.0 + Nd4j.getBlasWrapper().dot(buf, buf); - D = valP.getDouble(i) / D; - - // Sum positive force - posF.slice(n).addi(buf.muli(D)); - } - } - */ - } - - - - public boolean isLeaf() { - return isLeaf; - } - - /** - * Verifies the structure of the tree (does bounds checking on each node) - * @return true if the structure of the tree - * is correct. - */ - public boolean isCorrect() { - /*MemoryWorkspace workspace = - workspaceMode == WorkspaceMode.NONE ? 
new DummyWorkspace()
-                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
-                                        workspaceConfigurationExternal,
-                                        workspaceExternal);
-        try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ {
-
-            for (int n = 0; n < size; n++) {
-                INDArray point = data.slice(index[n]);
-                if (!boundary.contains(point))
-                    return false;
-            }
-            if (!isLeaf()) {
-                boolean correct = true;
-                for (int i = 0; i < numChildren; i++)
-                    correct = correct && children[i].isCorrect();
-                return correct;
-            }
-
-            return true;
-        }
-    }
-
-    /**
-     * The depth of the node
-     * @return the depth of the node
-     */
-    public int depth() {
-        if (isLeaf())
-            return 1;
-        int depth = 1;
-        int maxChildDepth = 0;
-        for (int i = 0; i < numChildren; i++) {
-            // fixed: the original queried children[0] on every iteration
-            maxChildDepth = Math.max(maxChildDepth, children[i].depth());
-        }
-
-        return depth + maxChildDepth;
-    }
-
-    private void fill(int n) {
-        if (indices.isEmpty() && parent == null)
-            for (int i = 0; i < n; i++) {
-                log.trace("Inserted " + i);
-                insert(i);
-            }
-        else
-            log.warn("Called fill already");
-    }
-
-
-    public SpTree[] getChildren() {
-        return children;
-    }
-
-    public int getD() {
-        return D;
-    }
-
-    public INDArray getCenterOfMass() {
-        return centerOfMass;
-    }
-
-    public Cell getBoundary() {
-        return boundary;
-    }
-
-    public int[] getIndex() {
-        return index;
-    }
-
-    public int getCumSize() {
-        return cumSize;
-    }
-
-    public void setCumSize(int cumSize) {
-        this.cumSize = cumSize;
-    }
-
-    public int getNumChildren() {
-        return numChildren;
-    }
-
-    public void setNumChildren(int numChildren) {
-        this.numChildren = numChildren;
-    }
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/BaseClusteringStrategy.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/BaseClusteringStrategy.java
deleted file mode 100644
index daada687f..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/BaseClusteringStrategy.java
+++ /dev/null
@@ -1,117 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.strategy; - -import lombok.*; -import org.deeplearning4j.clustering.algorithm.Distance; -import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition; -import org.deeplearning4j.clustering.condition.ConvergenceCondition; -import org.deeplearning4j.clustering.condition.FixedIterationCountCondition; - -import java.io.Serializable; - -@AllArgsConstructor(access = AccessLevel.PROTECTED) -@NoArgsConstructor(access = AccessLevel.PROTECTED) -public abstract class BaseClusteringStrategy implements ClusteringStrategy, Serializable { - @Getter(AccessLevel.PUBLIC) - @Setter(AccessLevel.PROTECTED) - protected ClusteringStrategyType type; - @Getter(AccessLevel.PUBLIC) - @Setter(AccessLevel.PROTECTED) - protected Integer initialClusterCount; - @Getter(AccessLevel.PUBLIC) - @Setter(AccessLevel.PROTECTED) - protected ClusteringAlgorithmCondition optimizationPhaseCondition; - @Getter(AccessLevel.PUBLIC) - @Setter(AccessLevel.PROTECTED) - protected ClusteringAlgorithmCondition terminationCondition; - @Getter(AccessLevel.PUBLIC) - @Setter(AccessLevel.PROTECTED) - protected boolean inverse; - @Getter(AccessLevel.PUBLIC) - @Setter(AccessLevel.PROTECTED) - protected Distance distanceFunction; - @Getter(AccessLevel.PUBLIC) - @Setter(AccessLevel.PROTECTED) - protected boolean allowEmptyClusters; - - public BaseClusteringStrategy(ClusteringStrategyType type, Integer initialClusterCount, Distance distanceFunction, - boolean allowEmptyClusters, boolean inverse) { - this.type = type; - this.initialClusterCount = initialClusterCount; - this.distanceFunction = distanceFunction; - this.allowEmptyClusters = allowEmptyClusters; - this.inverse = inverse; - } - - public BaseClusteringStrategy(ClusteringStrategyType clusteringStrategyType, int initialClusterCount, - Distance distanceFunction, boolean inverse) { - this(clusteringStrategyType, initialClusterCount, distanceFunction, false, inverse); - } - - - /** - * - * @param maxIterationCount - * @return - */ - public BaseClusteringStrategy endWhenIterationCountEquals(int maxIterationCount) { - setTerminationCondition(FixedIterationCountCondition.iterationCountGreaterThan(maxIterationCount)); - return this; - } - - /** - * - * @param rate - * @return - */ - public BaseClusteringStrategy endWhenDistributionVariationRateLessThan(double rate) { - setTerminationCondition(ConvergenceCondition.distributionVariationRateLessThan(rate)); - return this; - } - - /** - * @return - */ - @Override - public boolean inverseDistanceCalculation() { - return inverse; - } - - /** - * - * @param type - * @return - */ - public boolean isStrategyOfType(ClusteringStrategyType type) { - return type.equals(this.type); - } - - /** - * - * @return - */ - public Integer getInitialClusterCount() { - return initialClusterCount; - } - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/ClusteringStrategy.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/ClusteringStrategy.java deleted file mode 100644 index 2ec9fcd47..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/ClusteringStrategy.java +++ /dev/null @@ -1,102 +0,0 @@ -/* - * 
****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.strategy; - -import org.deeplearning4j.clustering.algorithm.Distance; -import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition; -import org.deeplearning4j.clustering.iteration.IterationHistory; - -/** - * - */ -public interface ClusteringStrategy { - - /** - * - * @return - */ - boolean inverseDistanceCalculation(); - - /** - * - * @return - */ - ClusteringStrategyType getType(); - - /** - * - * @param type - * @return - */ - boolean isStrategyOfType(ClusteringStrategyType type); - - /** - * - * @return - */ - Integer getInitialClusterCount(); - - /** - * - * @return - */ - Distance getDistanceFunction(); - - /** - * - * @return - */ - boolean isAllowEmptyClusters(); - - /** - * - * @return - */ - ClusteringAlgorithmCondition getTerminationCondition(); - - /** - * - * @return - */ - boolean isOptimizationDefined(); - - /** - * - * @param iterationHistory - * @return - */ - boolean isOptimizationApplicableNow(IterationHistory iterationHistory); - - /** - * - * @param maxIterationCount - * @return - */ - BaseClusteringStrategy endWhenIterationCountEquals(int maxIterationCount); - - /** - * - * @param rate - * @return - */ - BaseClusteringStrategy endWhenDistributionVariationRateLessThan(double rate); - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/ClusteringStrategyType.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/ClusteringStrategyType.java deleted file mode 100644 index 9f72bba95..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/ClusteringStrategyType.java +++ /dev/null @@ -1,25 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.strategy; - -public enum ClusteringStrategyType { - FIXED_CLUSTER_COUNT, OPTIMIZATION -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/FixedClusterCountStrategy.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/FixedClusterCountStrategy.java deleted file mode 100644 index 18eceb34f..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/FixedClusterCountStrategy.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.strategy; - -import lombok.AccessLevel; -import lombok.NoArgsConstructor; -import org.deeplearning4j.clustering.algorithm.Distance; -import org.deeplearning4j.clustering.iteration.IterationHistory; - -/** - * - */ -@NoArgsConstructor(access = AccessLevel.PROTECTED) -public class FixedClusterCountStrategy extends BaseClusteringStrategy { - - - protected FixedClusterCountStrategy(Integer initialClusterCount, Distance distanceFunction, - boolean allowEmptyClusters, boolean inverse) { - super(ClusteringStrategyType.FIXED_CLUSTER_COUNT, initialClusterCount, distanceFunction, allowEmptyClusters, - inverse); - } - - /** - * - * @param clusterCount - * @param distanceFunction - * @param inverse - * @return - */ - public static FixedClusterCountStrategy setup(int clusterCount, Distance distanceFunction, boolean inverse) { - return new FixedClusterCountStrategy(clusterCount, distanceFunction, false, inverse); - } - - /** - * @return - */ - @Override - public boolean inverseDistanceCalculation() { - return inverse; - } - - public boolean isOptimizationDefined() { - return false; - } - - public boolean isOptimizationApplicableNow(IterationHistory iterationHistory) { - return false; - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/OptimisationStrategy.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/OptimisationStrategy.java deleted file mode 100644 index dc9385296..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/strategy/OptimisationStrategy.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * 
****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.strategy; - -import org.deeplearning4j.clustering.algorithm.Distance; -import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition; -import org.deeplearning4j.clustering.condition.ConvergenceCondition; -import org.deeplearning4j.clustering.condition.FixedIterationCountCondition; -import org.deeplearning4j.clustering.iteration.IterationHistory; -import org.deeplearning4j.clustering.optimisation.ClusteringOptimization; -import org.deeplearning4j.clustering.optimisation.ClusteringOptimizationType; - -public class OptimisationStrategy extends BaseClusteringStrategy { - public static int defaultIterationCount = 100; - - private ClusteringOptimization clusteringOptimisation; - private ClusteringAlgorithmCondition clusteringOptimisationApplicationCondition; - - protected OptimisationStrategy() { - super(); - } - - protected OptimisationStrategy(int initialClusterCount, Distance distanceFunction) { - super(ClusteringStrategyType.OPTIMIZATION, initialClusterCount, distanceFunction, false); - } - - public static OptimisationStrategy setup(int initialClusterCount, Distance distanceFunction) { - return new OptimisationStrategy(initialClusterCount, distanceFunction); - } - - public OptimisationStrategy optimize(ClusteringOptimizationType type, double value) { - clusteringOptimisation = new ClusteringOptimization(type, value); - return this; - } - - public OptimisationStrategy optimizeWhenIterationCountMultipleOf(int value) { - clusteringOptimisationApplicationCondition = FixedIterationCountCondition.iterationCountGreaterThan(value); - return this; - } - - public OptimisationStrategy optimizeWhenPointDistributionVariationRateLessThan(double rate) { - clusteringOptimisationApplicationCondition = ConvergenceCondition.distributionVariationRateLessThan(rate); - return this; - } - - - public double getClusteringOptimizationValue() { - return clusteringOptimisation.getValue(); - } - - public boolean isClusteringOptimizationType(ClusteringOptimizationType type) { - return clusteringOptimisation != null && clusteringOptimisation.getType().equals(type); - } - - public boolean isOptimizationDefined() { - return clusteringOptimisation != null; - } - - public boolean isOptimizationApplicableNow(IterationHistory iterationHistory) { - return clusteringOptimisationApplicationCondition != null - && clusteringOptimisationApplicationCondition.isSatisfied(iterationHistory); - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/util/MathUtils.java 
b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/util/MathUtils.java deleted file mode 100755 index 2290c6269..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/util/MathUtils.java +++ /dev/null @@ -1,1327 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.util; - - -import org.apache.commons.math3.linear.CholeskyDecomposition; -import org.apache.commons.math3.linear.NonSquareMatrixException; -import org.apache.commons.math3.linear.RealMatrix; -import org.apache.commons.math3.random.RandomGenerator; -import org.apache.commons.math3.util.FastMath; -import org.nd4j.common.primitives.Counter; - -import java.util.ArrayList; -import java.util.List; -import java.util.Random; -import java.util.Set; - - -public class MathUtils { - - /** The natural logarithm of 2. 
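-     * Editor's note: cached so that log2(a), defined later in this class, can be
-     * computed as Math.log(a) / log2.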
*/ - public static double log2 = Math.log(2); - - /** - * Normalize a value - * (val - min) / (max - min) - * @param val value to normalize - * @param max max value - * @param min min value - * @return the normalized value - */ - public static double normalize(double val, double min, double max) { - if (max < min) - throw new IllegalArgumentException("Max must be greater than min"); - - return (val - min) / (max - min); - } - - /** - * Clamps the value to a discrete value - * @param value the value to clamp - * @param min min for the probability distribution - * @param max max for the probability distribution - * @return the discrete value - */ - public static int clamp(int value, int min, int max) { - if (value < min) - value = min; - if (value > max) - value = max; - return value; - } - - /** - * Discretize the given value - * @param value the value to discretize - * @param min the min of the distribution - * @param max the max of the distribution - * @param binCount the number of bins - * @return the discretized value - */ - public static int discretize(double value, double min, double max, int binCount) { - int discreteValue = (int) (binCount * normalize(value, min, max)); - return clamp(discreteValue, 0, binCount - 1); - } - - - /** - * See: https://stackoverflow.com/questions/466204/rounding-off-to-nearest-power-of-2 - * @param v the number to getFromOrigin the next power of 2 for - * @return the next power of 2 for the passed in value - */ - public static long nextPowOf2(long v) { - v--; - v |= v >> 1; - v |= v >> 2; - v |= v >> 4; - v |= v >> 8; - v |= v >> 16; - v++; - return v; - - } - - - - /** - * Generates a binomial distributed number using - * the given rng - * @param rng - * @param n - * @param p - * @return - */ - public static int binomial(RandomGenerator rng, int n, double p) { - if ((p < 0) || (p > 1)) { - return 0; - } - int c = 0; - for (int i = 0; i < n; i++) { - if (rng.nextDouble() < p) { - c++; - } - } - return c; - } - - /** - * Generate a uniform random number from the given rng - * @param rng the rng to use - * @param min the min num - * @param max the max num - * @return a number uniformly distributed between min and max - */ - public static double uniform(Random rng, double min, double max) { - return rng.nextDouble() * (max - min) + min; - } - - /** - * Returns the correlation coefficient of two double vectors. 
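-     * Editor's note: as implemented below this returns 1 - SS_err / SS_total computed
-     * from the residuals, i.e. the coefficient of determination r^2 rather than r itself.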
- * - * @param residuals residuals - * @param targetAttribute target attribute vector - * - * @return the correlation coefficient or r - */ - public static double correlation(double[] residuals, double targetAttribute[]) { - double[] predictedValues = new double[residuals.length]; - for (int i = 0; i < predictedValues.length; i++) { - predictedValues[i] = targetAttribute[i] - residuals[i]; - } - double ssErr = ssError(predictedValues, targetAttribute); - double total = ssTotal(residuals, targetAttribute); - return 1 - (ssErr / total); - }//end correlation - - /** - * 1 / 1 + exp(-x) - * @param x - * @return - */ - public static double sigmoid(double x) { - return 1.0 / (1.0 + FastMath.exp(-x)); - } - - - /** - * How much of the variance is explained by the regression - * @param residuals error - * @param targetAttribute data for target attribute - * @return the sum squares of regression - */ - public static double ssReg(double[] residuals, double[] targetAttribute) { - double mean = sum(targetAttribute) / targetAttribute.length; - double ret = 0; - for (int i = 0; i < residuals.length; i++) { - ret += Math.pow(residuals[i] - mean, 2); - } - return ret; - } - - /** - * How much of the variance is NOT explained by the regression - * @param predictedValues predicted values - * @param targetAttribute data for target attribute - * @return the sum squares of regression - */ - public static double ssError(double[] predictedValues, double[] targetAttribute) { - double ret = 0; - for (int i = 0; i < predictedValues.length; i++) { - ret += Math.pow(targetAttribute[i] - predictedValues[i], 2); - } - return ret; - - } - - - /** - * Calculate string similarity with tfidf weights relative to each character - * frequency and how many times a character appears in a given string - * @param strings the strings to calculate similarity for - * @return the cosine similarity between the strings - */ - public static double stringSimilarity(String... 
strings) {
-        if (strings == null)
-            return 0;
-        Counter<String> counter = new Counter<>();
-        Counter<String> counter2 = new Counter<>();
-
-        for (int i = 0; i < strings[0].length(); i++)
-            counter.incrementCount(String.valueOf(strings[0].charAt(i)), 1.0f);
-
-        for (int i = 0; i < strings[1].length(); i++)
-            counter2.incrementCount(String.valueOf(strings[1].charAt(i)), 1.0f);
-        Set<String> v1 = counter.keySet();
-        Set<String> v2 = counter2.keySet();
-
-
-        Set<String> both = SetUtils.intersection(v1, v2);
-
-        double scalar = 0, norm1 = 0, norm2 = 0;
-        for (String k : both)
-            scalar += counter.getCount(k) * counter2.getCount(k);
-        for (String k : v1)
-            norm1 += counter.getCount(k) * counter.getCount(k);
-        for (String k : v2)
-            norm2 += counter2.getCount(k) * counter2.getCount(k);
-        return scalar / Math.sqrt(norm1 * norm2);
-    }
-
-    /**
-     * Returns the squared vector length (sum of x_i^2)
-     * @param vector the vector to return the squared length for
-     * @return the squared vector length of the passed in array
-     */
-    public static double vectorLength(double[] vector) {
-        double ret = 0;
-        if (vector == null)
-            return ret;
-        else {
-            for (int i = 0; i < vector.length; i++) {
-                ret += Math.pow(vector[i], 2);
-            }
-
-        }
-        return ret;
-    }
-
-    /**
-     * Inverse document frequency: the total docs divided by the number of times the word
-     * appeared in a document
-     * @param totalDocs the total documents for the data set
-     * @param numTimesWordAppearedInADocument the number of times the word occurred in a document
-     * @return log(10) (totalDocs/numTimesWordAppearedInADocument)
-     */
-    public static double idf(double totalDocs, double numTimesWordAppearedInADocument) {
-        //return totalDocs > 0 ? Math.log10(totalDocs/numTimesWordAppearedInADocument) : 0;
-        if (totalDocs == 0)
-            return 0;
-        double idf = Math.log10(totalDocs / numTimesWordAppearedInADocument);
-        return idf;
-    }
-
-    /**
-     * Term frequency: count / documentLength
-     * @param count the count of a word or character in a given string or document
-     * @param documentLength the length of the document the count was taken from
-     * @return count / documentLength
-     */
-    public static double tf(int count, int documentLength) {
-        //return count > 0 ? 1 + Math.log10(count) : 0
-        double tf = ((double) count / documentLength);
-        return tf;
-    }
-
-    /**
-     * Return tf * idf
-     * @param tf the term frequency (assumed calculated)
-     * @param idf inverse document frequency (assumed calculated)
-     * @return tf * idf
-     */
-    public static double tfidf(double tf, double idf) {
-        // System.out.println("TF-IDF Value: " + (tf * idf));
-        return tf * idf;
-    }
-
-    private static int charForLetter(char c) {
-        char[] chars = {'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's',
-                        't', 'u', 'v', 'w', 'x', 'y', 'z'};
-        for (int i = 0; i < chars.length; i++)
-            if (chars[i] == c)
-                return i;
-        return -1;
-
-    }
-
-
-
-    /**
-     * Total variance in target attribute
-     * @param residuals error
-     * @param targetAttribute data for target attribute
-     * @return Total variance in target attribute
-     */
-    public static double ssTotal(double[] residuals, double[] targetAttribute) {
-        return ssReg(residuals, targetAttribute) + ssError(residuals, targetAttribute);
-    }
-
-    /**
-     * This returns the sum of the given array.
-     * @param nums the array of numbers to sum
-     * @return the sum of the given array
-     */
-    public static double sum(double[] nums) {
-
-        double ret = 0;
-        for (double d : nums)
-            ret += d;
-
-        return ret;
-    }//end sum
-
-    /**
-     * This will merge the coordinates of the given coordinate system.
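-     * Example (editor's sketch, hypothetical values):
-     *   mergeCoords(new double[] {1, 2}, new double[] {10, 20}) -> [1, 10, 2, 20]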
-     * @param x the x coordinates
-     * @param y the y coordinates
-     * @return a vector such that each (x,y) pair is at ret[2*i], ret[2*i+1]
-     */
-    public static double[] mergeCoords(double[] x, double[] y) {
-        if (x.length != y.length)
-            throw new IllegalArgumentException(
-                            "Sample sizes must be the same for each data set.");
-        double[] ret = new double[x.length + y.length];
-
-        for (int i = 0; i < x.length; i++) {
-            // fixed: the original wrote to ret[i] and ret[i + 1], overwriting each
-            // y value with the next x value
-            ret[2 * i] = x[i];
-            ret[2 * i + 1] = y[i];
-        }
-        return ret;
-    }//end mergeCoords
-
-    /**
-     * This will merge the coordinates of the given coordinate system.
-     * @param x the x coordinates
-     * @param y the y coordinates
-     * @return a vector such that each (x,y) pair is at ret[2*i], ret[2*i+1]
-     */
-    public static List<Double> mergeCoords(List<Double> x, List<Double> y) {
-        if (x.size() != y.size())
-            throw new IllegalArgumentException(
-                            "Sample sizes must be the same for each data set.");
-
-        List<Double> ret = new ArrayList<>();
-
-        for (int i = 0; i < x.size(); i++) {
-            ret.add(x.get(i));
-            ret.add(y.get(i));
-        }
-        return ret;
-    }//end mergeCoords
-
-    /**
-     * This returns the minimized loss values for a given vector.
-     * It is assumed that the x, y pairs are at
-     * vector[i], vector[i+1]
-     * @param vector the vector of numbers to get the weights for
-     * @return a double array with w_0 and w_1 at indices 0 and 1.
-     */
-    public static double[] weightsFor(List<Double> vector) {
-        /* split coordinate system */
-        List<double[]> coords = coordSplit(vector);
-        /* x vals */
-        double[] x = coords.get(0);
-        /* y vals */
-        double[] y = coords.get(1);
-
-
-        double meanX = sum(x) / x.length;
-        double meanY = sum(y) / y.length;
-
-        double sumOfMeanDifferences = sumOfMeanDifferences(x, y);
-        double xDifferenceOfMean = sumOfMeanDifferencesOnePoint(x);
-
-        double w_1 = sumOfMeanDifferences / xDifferenceOfMean;
-
-        double w_0 = meanY - (w_1) * meanX;
-
-        //double w_1=(n*sumOfProducts(x,y) - sum(x) * sum(y))/(n*sumOfSquares(x) - Math.pow(sum(x),2));
-
-        // double w_0=(sum(y) - (w_1 * sum(x)))/n;
-
-        double[] ret = new double[vector.size()];
-        ret[0] = w_0;
-        ret[1] = w_1;
-
-        return ret;
-    }//end weightsFor
-
-    /**
-     * This will return the squared loss of the given
-     * points
-     * @param x the x coordinates to use
-     * @param y the y coordinates to use
-     * @param w_0 the first weight
-     *
-     * @param w_1 the second weight
-     * @return the squared loss of the given points
-     */
-    public static double squaredLoss(double[] x, double[] y, double w_0, double w_1) {
-        double sum = 0;
-        for (int j = 0; j < x.length; j++) {
-            sum += Math.pow((y[j] - (w_1 * x[j] + w_0)), 2);
-        }
-        return sum;
-    }//end squaredLoss
-
-
-    public static double w_1(double[] x, double[] y, int n) {
-        return (n * sumOfProducts(x, y) - sum(x) * sum(y)) / (n * sumOfSquares(x) - Math.pow(sum(x), 2));
-    }
-
-    public static double w_0(double[] x, double[] y, int n) {
-        double weight1 = w_1(x, y, n);
-
-        return (sum(y) - (weight1 * sum(x))) / n;
-    }
-
-    /**
-     * This returns the minimized loss values for a given vector.
-     * It is assumed that the x, y pairs are at
-     * vector[i], vector[i+1]
-     * @param vector the vector of numbers to get the weights for
-     * @return a double array with w_0 and w_1 at indices 0 and 1.
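-     * Editor's note: this is ordinary least squares for y = w_0 + w_1 * x, with
-     *   w_1 = sum((x_i - meanX) * (y_i - meanY)) / sum((x_i - meanX)^2)
-     *   w_0 = meanY - w_1 * meanX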
-     */
-    public static double[] weightsFor(double[] vector) {
-
-        /* split coordinate system */
-        List<double[]> coords = coordSplit(vector);
-        /* x vals */
-        double[] x = coords.get(0);
-        /* y vals */
-        double[] y = coords.get(1);
-
-
-        double meanX = sum(x) / x.length;
-        double meanY = sum(y) / y.length;
-
-        double sumOfMeanDifferences = sumOfMeanDifferences(x, y);
-        double xDifferenceOfMean = sumOfMeanDifferencesOnePoint(x);
-
-        double w_1 = sumOfMeanDifferences / xDifferenceOfMean;
-
-        double w_0 = meanY - (w_1) * meanX;
-
-
-
-        double[] ret = new double[vector.length];
-        ret[0] = w_0;
-        ret[1] = w_1;
-
-        return ret;
-    }//end weightsFor
-
-    public static double errorFor(double actual, double prediction) {
-        return actual - prediction;
-    }
-
-    /**
-     * Used for calculating top part of simple regression for
-     * beta 1
-     * @param vector the x coordinates
-     * @param vector2 the y coordinates
-     * @return the sum of mean differences for the input vectors
-     */
-    public static double sumOfMeanDifferences(double[] vector, double[] vector2) {
-        double mean = sum(vector) / vector.length;
-        double mean2 = sum(vector2) / vector2.length;
-        double ret = 0;
-        for (int i = 0; i < vector.length; i++) {
-            double vec1Diff = vector[i] - mean;
-            double vec2Diff = vector2[i] - mean2;
-            ret += vec1Diff * vec2Diff;
-        }
-        return ret;
-    }//end sumOfMeanDifferences
-
-    /**
-     * Used for calculating top part of simple regression for
-     * beta 1
-     * @param vector the x coordinates
-     * @return the sum of mean differences for the input vector
-     */
-    public static double sumOfMeanDifferencesOnePoint(double[] vector) {
-        double mean = sum(vector) / vector.length;
-        double ret = 0;
-        for (int i = 0; i < vector.length; i++) {
-            double vec1Diff = Math.pow(vector[i] - mean, 2);
-            ret += vec1Diff;
-        }
-        return ret;
-    }//end sumOfMeanDifferences
-
-    public static double variance(double[] vector) {
-        return sumOfMeanDifferencesOnePoint(vector) / vector.length;
-    }
-
-    /**
-     * This returns the product of all numbers in the given array.
-     * @param nums the numbers to multiply over
-     * @return the product of all numbers in the array, or 0
-     * if the length is 0 or nums is null
-     */
-    public static double times(double[] nums) {
-        if (nums == null || nums.length == 0)
-            return 0;
-        double ret = 1;
-        for (int i = 0; i < nums.length; i++)
-            ret *= nums[i];
-        return ret;
-    }//end times
-
-
-    /**
-     * This returns the sum of products for the given
-     * numbers.
-     * @param nums the numbers to compute the sum of products for
-     * @return the sum of products for the given numbers
-     */
-    public static double sumOfProducts(double[]... nums) {
-        if (nums == null || nums.length < 1)
-            return 0;
-        double sum = 0;
-
-        for (int i = 0; i < nums.length; i++) {
-            /* The ith column for all of the rows */
-            double[] column = column(i, nums);
-            sum += times(column);
-
-        }
-        return sum;
-    }//end sumOfProducts
-
-
-    /**
-     * This returns the given column over n arrays
-     * @param column the column to get values for
-     * @param nums the arrays to extract values from
-     * @return a double array containing all of the numbers in that column
-     * for all of the arrays.
-     * @throws IllegalArgumentException if the index is < 0
-     */
-    private static double[] column(int column, double[]...
nums) throws IllegalArgumentException {
-
-        double[] ret = new double[nums.length];
-
-        for (int i = 0; i < nums.length; i++) {
-            double[] curr = nums[i];
-            ret[i] = curr[column];
-        }
-        return ret;
-    }//end column
-
-    /**
-     * This returns the coordinate split in a list of coordinates
-     * such that the values for ret[0] are the x values
-     * and ret[1] are the y values
-     * @param vector the vector to split with x and y values
-     * @return a coordinate split for the given vector of values.
-     * if null is passed in, null is returned
-     */
-    public static List<double[]> coordSplit(double[] vector) {
-
-        if (vector == null)
-            return null;
-        List<double[]> ret = new ArrayList<>();
-        /* x coordinates */
-        double[] xVals = new double[vector.length / 2];
-        /* y coordinates */
-        double[] yVals = new double[vector.length / 2];
-        /* current points */
-        int xTracker = 0;
-        int yTracker = 0;
-        for (int i = 0; i < vector.length; i++) {
-            //even value, x coordinate
-            if (i % 2 == 0)
-                xVals[xTracker++] = vector[i];
-            //y coordinate
-            else
-                yVals[yTracker++] = vector[i];
-        }
-        ret.add(xVals);
-        ret.add(yVals);
-
-        return ret;
-    }//end coordSplit
-
-
-    /**
-     * This returns the coordinate split in a list of coordinates
-     * such that the values for ret[0] are the x values
-     * and ret[1] are the y values
-     * @param vector the vector to split with x and y values
-     * Note that the list will be more stable due to the size operator.
-     * The array version will have extraneous values if not monitored
-     * properly.
-     * @return a coordinate split for the given vector of values.
-     * if null is passed in, null is returned
-     */
-    public static List<double[]> coordSplit(List<Double> vector) {
-
-        if (vector == null)
-            return null;
-        List<double[]> ret = new ArrayList<>();
-        /* x coordinates */
-        double[] xVals = new double[vector.size() / 2];
-        /* y coordinates */
-        double[] yVals = new double[vector.size() / 2];
-        /* current points */
-        int xTracker = 0;
-        int yTracker = 0;
-        for (int i = 0; i < vector.size(); i++) {
-            //even value, x coordinate
-            if (i % 2 == 0)
-                xVals[xTracker++] = vector.get(i);
-            //y coordinate
-            else
-                yVals[yTracker++] = vector.get(i);
-        }
-        ret.add(xVals);
-        ret.add(yVals);
-
-        return ret;
-    }//end coordSplit
-
-
-
-    /**
-     * This returns the x values of the given vector.
-     * These are assumed to be the even indexed values of the vector.
-     * @param vector the vector to get the values for
-     * @return the x values of the given vector
-     */
-    public static double[] xVals(double[] vector) {
-
-
-        if (vector == null)
-            return null;
-        double[] x = new double[vector.length / 2];
-        int count = 0;
-        for (int i = 0; i < vector.length; i++) {
-            // fixed: the original tested i % 2 != 0 here, returning y values
-            if (i % 2 == 0)
-                x[count++] = vector[i];
-        }
-        return x;
-    }//end xVals
-
-    /**
-     * This returns the odd indexed values for the given vector
-     * @param vector the vector to get the odd indexed values of
-     * @return the y values of the given vector
-     */
-    public static double[] yVals(double[] vector) {
-        double[] y = new double[vector.length / 2];
-        int count = 0;
-        for (int i = 0; i < vector.length; i++) {
-            // fixed: the original tested i % 2 == 0 here, returning x values
-            if (i % 2 != 0)
-                y[count++] = vector[i];
-        }
-        return y;
-    }//end yVals
-
-
-    /**
-     * This returns the sum of squares for the given vector.
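-     * Example (editor's sketch): sumOfSquares(new double[] {3, 4}) == 25.0.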
- * - * @param vector the vector to obtain the sum of squares for - * @return the sum of squares for this vector - */ - public static double sumOfSquares(double[] vector) { - double ret = 0; - for (double d : vector) - ret += Math.pow(d, 2); - return ret; - } - - /** - * This returns the determination coefficient of two vectors given a length - * @param y1 the first vector - * @param y2 the second vector - * @param n the length of both vectors - * @return the determination coefficient or r^2 - */ - public static double determinationCoefficient(double[] y1, double[] y2, int n) { - return Math.pow(correlation(y1, y2), 2); - } - - - - /** - * Returns the logarithm of a for base 2. - * - * @param a a double - * @return the logarithm for base 2 - */ - public static double log2(double a) { - if (a == 0) - return 0.0; - return Math.log(a) / log2; - } - - /** - * This returns the slope of the given points. - * @param x1 the first x to use - * @param x2 the end x to use - * @param y1 the begin y to use - * @param y2 the end y to use - * @return the slope of the given points - */ - public double slope(double x1, double x2, double y1, double y2) { - return (y2 - y1) / (x2 - x1); - }//end slope - - /** - * This returns the root mean squared error of two data sets - * @param real the real values - * @param predicted the predicted values - * @return the root means squared error for two data sets - */ - public static double rootMeansSquaredError(double[] real, double[] predicted) { - double ret = 0.0; - for (int i = 0; i < real.length; i++) { - ret += Math.pow((real[i] - predicted[i]), 2); - } - return Math.sqrt(ret / real.length); - }//end rootMeansSquaredError - - /** - * This returns the entropy (information gain, or uncertainty of a random variable). - * @param vector the vector of values to getFromOrigin the entropy for - * @return the entropy of the given vector - */ - public static double entropy(double[] vector) { - if (vector == null || vector.length < 1) - return 0; - else { - double ret = 0; - for (double d : vector) - ret += d * Math.log(d); - return ret; - - } - }//end entropy - - /** - * This returns the kronecker delta of two doubles. - * @param i the first number to compare - * @param j the second number to compare - * @return 1 if they are equal, 0 otherwise - */ - public static int kroneckerDelta(double i, double j) { - return (i == j) ? 1 : 0; - } - - /** - * This calculates the adjusted r^2 including degrees of freedom. - * Also known as calculating "strength" of a regression - * @param rSquared the r squared value to calculate - * @param numRegressors number of variables - * @param numDataPoints size of the data applyTransformToDestination - * @return an adjusted r^2 for degrees of freedom - */ - public static double adjustedrSquared(double rSquared, int numRegressors, int numDataPoints) { - double divide = (numDataPoints - 1.0) / (numDataPoints - numRegressors - 1.0); - double rSquaredDiff = 1 - rSquared; - return 1 - (rSquaredDiff * divide); - } - - - public static double[] normalizeToOne(double[] doubles) { - normalize(doubles, sum(doubles)); - return doubles; - } - - public static double min(double[] doubles) { - double ret = doubles[0]; - for (double d : doubles) - if (d < ret) - ret = d; - return ret; - } - - public static double max(double[] doubles) { - double ret = doubles[0]; - for (double d : doubles) - if (d > ret) - ret = d; - return ret; - } - - /** - * Normalizes the doubles in the array using the given value. 
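-     * Example (editor's sketch): normalize(new double[] {1, 2, 3}, 6.0) rescales the
-     * array in place to {1.0/6, 2.0/6, 3.0/6}.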
- * - * @param doubles the array of double - * @param sum the value by which the doubles are to be normalized - * @exception IllegalArgumentException if sum is zero or NaN - */ - public static void normalize(double[] doubles, double sum) { - - if (Double.isNaN(sum)) { - throw new IllegalArgumentException("Can't normalize array. Sum is NaN."); - } - if (sum == 0) { - // Maybe this should just be a return. - throw new IllegalArgumentException("Can't normalize array. Sum is zero."); - } - for (int i = 0; i < doubles.length; i++) { - doubles[i] /= sum; - } - }//end normalize - - /** - * Converts an array containing the natural logarithms of - * probabilities stored in a vector back into probabilities. - * The probabilities are assumed to sum to one. - * - * @param a an array holding the natural logarithms of the probabilities - * @return the converted array - */ - public static double[] logs2probs(double[] a) { - - double max = a[maxIndex(a)]; - double sum = 0.0; - - double[] result = new double[a.length]; - for (int i = 0; i < a.length; i++) { - result[i] = Math.exp(a[i] - max); - sum += result[i]; - } - - normalize(result, sum); - - return result; - }//end logs2probs - - /** - * This returns the entropy for a given vector of probabilities. - * @param probabilities the probabilities to getFromOrigin the entropy for - * @return the entropy of the given probabilities. - */ - public static double information(double[] probabilities) { - double total = 0.0; - for (double d : probabilities) { - total += (-1.0 * log2(d) * d); - } - return total; - }//end information - - /** - * - * - * Returns index of maximum element in a given - * array of doubles. First maximum is returned. - * - * @param doubles the array of doubles - * @return the index of the maximum element - */ - public static /*@pure@*/ int maxIndex(double[] doubles) { - - double maximum = 0; - int maxIndex = 0; - - for (int i = 0; i < doubles.length; i++) { - if ((i == 0) || (doubles[i] > maximum)) { - maxIndex = i; - maximum = doubles[i]; - } - } - - return maxIndex; - }//end maxIndex - - /** - * This will return the factorial of the given number n. - * @param n the number to getFromOrigin the factorial for - * @return the factorial for this number - */ - public static double factorial(double n) { - if (n == 1 || n == 0) - return 1; - for (double i = n; i > 0; i--, n *= (i > 0 ? i : 1)) { - } - return n; - }//end factorial - - - - /** The small deviation allowed in double comparisons. */ - public static double SMALL = 1e-6; - - /** - * Returns the log-odds for a given probability. - * - * @param prob the probability - * - * @return the log-odds after the probability has been mapped to - * [Utils.SMALL, 1-Utils.SMALL] - */ - public static /*@pure@*/ double probToLogOdds(double prob) { - - if (gr(prob, 1) || (sm(prob, 0))) { - throw new IllegalArgumentException("probToLogOdds: probability must " + "be in [0,1] " + prob); - } - double p = SMALL + (1.0 - 2 * SMALL) * prob; - return Math.log(p / (1 - p)); - } - - /** - * Rounds a double to the next nearest integer value. The JDK version - * of it doesn't work properly. - * - * @param value the double value - * @return the resulting integer value - */ - public static /*@pure@*/ int round(double value) { - - return value > 0 ? (int) (value + 0.5) : -(int) (Math.abs(value) + 0.5); - }//end round - - /** - * This returns the permutation of n choose r. 
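-     * Editor's note: computed as n! / (n - r)! using the factorial helper above;
-     * double-precision factorials overflow quickly, so this is only exact for small n.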
- * @param n the n to choose - * @param r the number of elements to choose - * @return the permutation of these numbers - */ - public static double permutation(double n, double r) { - double nFac = MathUtils.factorial(n); - double nMinusRFac = MathUtils.factorial((n - r)); - return nFac / nMinusRFac; - }//end permutation - - - /** - * This returns the combination of n choose r - * @param n the number of elements overall - * @param r the number of elements to choose - * @return the amount of possible combinations for this applyTransformToDestination of elements - */ - public static double combination(double n, double r) { - double nFac = MathUtils.factorial(n); - double rFac = MathUtils.factorial(r); - double nMinusRFac = MathUtils.factorial((n - r)); - - return nFac / (rFac * nMinusRFac); - }//end combination - - - /** - * sqrt(a^2 + b^2) without under/overflow. - */ - public static double hypotenuse(double a, double b) { - double r; - if (Math.abs(a) > Math.abs(b)) { - r = b / a; - r = Math.abs(a) * Math.sqrt(1 + r * r); - } else if (b != 0) { - r = a / b; - r = Math.abs(b) * Math.sqrt(1 + r * r); - } else { - r = 0.0; - } - return r; - }//end hypotenuse - - /** - * Rounds a double to the next nearest integer value in a probabilistic - * fashion (e.g. 0.8 has a 20% chance of being rounded down to 0 and a - * 80% chance of being rounded up to 1). In the limit, the average of - * the rounded numbers generated by this procedure should converge to - * the original double. - * - * @param value the double value - * @param rand the random number generator - * @return the resulting integer value - */ - public static int probRound(double value, Random rand) { - - if (value >= 0) { - double lower = Math.floor(value); - double prob = value - lower; - if (rand.nextDouble() < prob) { - return (int) lower + 1; - } else { - return (int) lower; - } - } else { - double lower = Math.floor(Math.abs(value)); - double prob = Math.abs(value) - lower; - if (rand.nextDouble() < prob) { - return -((int) lower + 1); - } else { - return -(int) lower; - } - } - }//end probRound - - /** - * Rounds a double to the given number of decimal places. - * - * @param value the double value - * @param afterDecimalPoint the number of digits after the decimal point - * @return the double rounded to the given precision - */ - public static /*@pure@*/ double roundDouble(double value, int afterDecimalPoint) { - - double mask = Math.pow(10.0, (double) afterDecimalPoint); - - return (double) (Math.round(value * mask)) / mask; - }//end roundDouble - - - - /** - * Rounds a double to the given number of decimal places. - * - * @param value the double value - * @param afterDecimalPoint the number of digits after the decimal point - * @return the double rounded to the given precision - */ - public static /*@pure@*/ float roundFloat(float value, int afterDecimalPoint) { - - float mask = (float) Math.pow(10, (float) afterDecimalPoint); - - return (float) (Math.round(value * mask)) / mask; - }//end roundDouble - - /** - * This will return the bernoulli trial for the given event. - * A bernoulli trial is a mechanism for detecting the probability - * of a given event occurring k times in n independent trials - * @param n the number of trials - * @param k the number of times the target event occurs - * @param successProb the probability of the event happening - * @return the probability of the given event occurring k times. 
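-     * Editor's note: this is the binomial probability C(n, k) * p^k * (1 - p)^(n - k),
-     * built from the combination helper above.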
- */ - public static double bernoullis(double n, double k, double successProb) { - - double combo = MathUtils.combination(n, k); - double q = 1 - successProb; - return combo * Math.pow(successProb, k) * Math.pow(q, n - k); - }//end bernoullis - - /** - * Tests if a is smaller than b. - * - * @param a a double - * @param b a double - */ - public static /*@pure@*/ boolean sm(double a, double b) { - - return (b - a > SMALL); - } - - /** - * Tests if a is greater than b. - * - * @param a a double - * @param b a double - */ - public static /*@pure@*/ boolean gr(double a, double b) { - - return (a - b > SMALL); - } - - /** - * This will take a given string and separator and convert it to an equivalent - * double array. - * @param data the data to separate - * @param separator the separator to use - * @return the new double array based on the given data - */ - public static double[] fromString(String data, String separator) { - String[] split = data.split(separator); - double[] ret = new double[split.length]; - for (int i = 0; i < split.length; i++) { - ret[i] = Double.parseDouble(split[i]); - } - return ret; - }//end fromString - - /** - * Computes the mean for an array of doubles. - * - * @param vector the array - * @return the mean - */ - public static /*@pure@*/ double mean(double[] vector) { - - double sum = 0; - - if (vector.length == 0) { - return 0; - } - for (int i = 0; i < vector.length; i++) { - sum += vector[i]; - } - return sum / (double) vector.length; - }//end mean - - /** - * This will return the cholesky decomposition of - * the given matrix - * @param m the matrix to convert - * @return the cholesky decomposition of the given - * matrix. - * See: - * http://en.wikipedia.org/wiki/Cholesky_decomposition - * @throws NonSquareMatrixException - */ - public CholeskyDecomposition choleskyFromMatrix(RealMatrix m) throws Exception { - return new CholeskyDecomposition(m); - }//end choleskyFromMatrix - - - - /** - * This will convert the given binary string to a decimal based - * integer - * @param binary the binary string to convert - * @return an equivalent base 10 number - */ - public static int toDecimal(String binary) { - long num = Long.parseLong(binary); - long rem; - /* Use the remainder method to ensure validity */ - while (num > 0) { - rem = num % 10; - num = num / 10; - if (rem != 0 && rem != 1) { - System.out.println("This is not a binary number."); - System.out.println("Please try once again."); - return -1; - } - } - return Integer.parseInt(binary, 2); - }//end toDecimal - - - /** - * This will translate a vector in to an equivalent integer - * @param vector the vector to translate - * @return a z value such that the value is the interleaved lsd to msd for each - * double in the vector - */ - public static int distanceFinderZValue(double[] vector) { - StringBuilder binaryBuffer = new StringBuilder(); - List binaryReps = new ArrayList<>(vector.length); - for (int i = 0; i < vector.length; i++) { - double d = vector[i]; - int j = (int) d; - String binary = Integer.toBinaryString(j); - binaryReps.add(binary); - } - //append from left to right, the least to the most significant bit - //till all strings are empty - while (!binaryReps.isEmpty()) { - for (int j = 0; j < binaryReps.size(); j++) { - String curr = binaryReps.get(j); - if (!curr.isEmpty()) { - char first = curr.charAt(0); - binaryBuffer.append(first); - curr = curr.substring(1); - binaryReps.set(j, curr); - } else - binaryReps.remove(j); - } - } - return Integer.parseInt(binaryBuffer.toString(), 2); - - }//end 
distanceFinderZValue
-
-    /**
-     * This returns the squared distance of two vectors:
-     * sum(i=1,n) (q_i - p_i)^2
-     * @param p the first vector
-     * @param q the second vector
-     * @return the squared distance between the two vectors
-     */
-    public static double euclideanDistance(double[] p, double[] q) {
-
-        double ret = 0;
-        for (int i = 0; i < p.length; i++) {
-            double diff = (q[i] - p[i]);
-            double sq = Math.pow(diff, 2);
-            ret += sq;
-        }
-        return ret;
-
-    }//end euclideanDistance
-
-    /**
-     * This returns the squared distance of two vectors:
-     * sum(i=1,n) (q_i - p_i)^2
-     * @param p the first vector
-     * @param q the second vector
-     * @return the squared distance between the two vectors
-     */
-    public static double euclideanDistance(float[] p, float[] q) {
-
-        double ret = 0;
-        for (int i = 0; i < p.length; i++) {
-            double diff = (q[i] - p[i]);
-            double sq = Math.pow(diff, 2);
-            ret += sq;
-        }
-        return ret;
-
-    }//end euclideanDistance
-
-    /**
-     * This will generate a series of l uniformly distributed
-     * random numbers between 0 and 1
-     * @param l the number of numbers to generate
-     * @return l uniformly generated numbers
-     */
-    public static double[] generateUniform(int l) {
-        double[] ret = new double[l];
-        Random rgen = new Random();
-        for (int i = 0; i < l; i++) {
-            ret[i] = rgen.nextDouble();
-        }
-        return ret;
-    }//end generateUniform
-
-
-    /**
-     * This will calculate the Manhattan distance between two sets of points.
-     * The Manhattan distance is equivalent to:
-     * sum(i=1,n) |p_i - q_i|
-     * @param p the first point vector
-     * @param q the second point vector
-     * @return the Manhattan distance between the two objects
-     */
-    public static double manhattanDistance(double[] p, double[] q) {
-
-        double ret = 0;
-        for (int i = 0; i < p.length; i++) {
-            double difference = p[i] - q[i];
-            ret += Math.abs(difference);
-        }
-        return ret;
-    }//end manhattanDistance
-
-
-
-    public static double[] sampleDoublesInInterval(double[][] doubles, int l) {
-        double[] sample = new double[l];
-        for (int i = 0; i < l; i++) {
-            int rand1 = randomNumberBetween(0, doubles.length - 1);
-            // fixed: the original indexed doubles[i] and allowed rand2 == length,
-            // both of which can throw ArrayIndexOutOfBoundsException
-            int rand2 = randomNumberBetween(0, doubles[rand1].length - 1);
-            sample[i] = doubles[rand1][rand2];
-        }
-
-        return sample;
-    }
-
-    /**
-     * Generates a random integer between the specified numbers
-     * @param begin the begin of the interval
-     * @param end the end of the interval
-     * @return an int between begin and end
-     */
-    public static int randomNumberBetween(double begin, double end) {
-        if (begin > end)
-            throw new IllegalArgumentException("Begin must not be greater than end");
-        return (int) begin + (int) (Math.random() * ((end - begin) + 1));
-    }
-
-    /**
-     * Generates a random integer between the specified numbers
-     * @param begin the begin of the interval
-     * @param end the end of the interval
-     * @return an int between begin and end
-     */
-    public static int randomNumberBetween(double begin, double end, RandomGenerator rng) {
-        if (begin > end)
-            throw new IllegalArgumentException("Begin must not be greater than end");
-        return (int) begin + (int) (rng.nextDouble() * ((end - begin) + 1));
-    }
-
-    /**
-     * Generates a random integer between the specified numbers
-     * @param begin the begin of the interval
-     * @param end the end of the interval
-     * @return an int between begin and end
-     */
-    public static int randomNumberBetween(double begin, double end, org.nd4j.linalg.api.rng.Random rng) {
-        if (begin > end)
-            throw new IllegalArgumentException("Begin must not be greater than end");
-        return (int) begin + (int) (rng.nextDouble() * ((end - begin) + 1));
-    }
-
-    /**
-     *
-     * @param begin
-     *
@param end - * @return - */ - public static float randomFloatBetween(float begin, float end) { - float rand = (float) Math.random(); - return begin + (rand * ((end - begin))); - } - - public static double randomDoubleBetween(double begin, double end) { - return begin + (Math.random() * ((end - begin))); - } - - public static void shuffleArray(int[] array, long rngSeed) { - shuffleArray(array, new Random(rngSeed)); - } - - public static void shuffleArray(int[] array, Random rng) { - //https://en.wikipedia.org/wiki/Fisher%E2%80%93Yates_shuffle#The_modern_algorithm - for (int i = array.length - 1; i > 0; i--) { - int j = rng.nextInt(i + 1); - int temp = array[j]; - array[j] = array[i]; - array[i] = temp; - } - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/util/MultiThreadUtils.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/util/MultiThreadUtils.java deleted file mode 100644 index c147c474e..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/util/MultiThreadUtils.java +++ /dev/null @@ -1,74 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
-
-package org.deeplearning4j.clustering.util;
-
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
-
-import java.util.List;
-import java.util.concurrent.*;
-
-public class MultiThreadUtils {
-
-    private static Logger log = LoggerFactory.getLogger(MultiThreadUtils.class);
-
-    private MultiThreadUtils() {}
-
-    public static synchronized ExecutorService newExecutorService() {
-        int nThreads = Runtime.getRuntime().availableProcessors();
-        // Daemon threads, so a forgotten pool cannot keep the JVM alive
-        return new ThreadPoolExecutor(nThreads, nThreads, 60L, TimeUnit.SECONDS, new LinkedTransferQueue<Runnable>(),
-                        new ThreadFactory() {
-                            @Override
-                            public Thread newThread(Runnable r) {
-                                Thread t = Executors.defaultThreadFactory().newThread(r);
-                                t.setDaemon(true);
-                                return t;
-                            }
-                        });
-    }
-
-    public static void parallelTasks(final List<Runnable> tasks, ExecutorService executorService) {
-        int tasksCount = tasks.size();
-        final CountDownLatch latch = new CountDownLatch(tasksCount);
-        for (int i = 0; i < tasksCount; i++) {
-            final int taskIdx = i;
-            executorService.execute(new Runnable() {
-                public void run() {
-                    try {
-                        tasks.get(taskIdx).run();
-                    } catch (Throwable e) {
-                        log.error("Unchecked exception thrown by task", e);
-                    } finally {
-                        latch.countDown();
-                    }
-                }
-            });
-        }
-
-        try {
-            latch.await();
-        } catch (Exception e) {
-            throw new RuntimeException(e);
-        }
-    }
-}
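For readers unfamiliar with the latch pattern used by `parallelTasks` above, here is a minimal usage sketch (assuming `MultiThreadUtils` is on the classpath; the task bodies are invented for illustration). Because every task counts down the shared `CountDownLatch` in its `finally` block, `parallelTasks` blocks the caller until all tasks have finished:

```java
import org.deeplearning4j.clustering.util.MultiThreadUtils;

import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;

public class ParallelTasksSketch {
    public static void main(String[] args) {
        ExecutorService pool = MultiThreadUtils.newExecutorService();
        // Two trivial tasks; both are guaranteed to have printed
        // before "done", though their relative order is not fixed.
        List<Runnable> tasks = Arrays.asList(
                () -> System.out.println("task 1"),
                () -> System.out.println("task 2"));
        MultiThreadUtils.parallelTasks(tasks, pool);
        System.out.println("done");
        pool.shutdown();
    }
}
```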
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/util/SetUtils.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/util/SetUtils.java
deleted file mode 100755
index eecf576d0..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/util/SetUtils.java
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.util;
-
-import java.util.Collection;
-import java.util.HashSet;
-import java.util.Set;
-
-public class SetUtils {
-    private SetUtils() {}
-
-    // Set-specific operations
-
-    /** Returns the intersection of the two collections (elements of parentCollection also present in removeFromCollection). */
-    public static <T> Set<T> intersection(Collection<T> parentCollection, Collection<T> removeFromCollection) {
-        Set<T> results = new HashSet<>(parentCollection);
-        results.retainAll(removeFromCollection);
-        return results;
-    }
-
-    /** Returns true if s1 and s2 share at least one element. */
-    public static <T> boolean intersectionP(Set<T> s1, Set<T> s2) {
-        for (T elt : s1) {
-            if (s2.contains(elt))
-                return true;
-        }
-        return false;
-    }
-
-    public static <T> Set<T> union(Set<T> s1, Set<T> s2) {
-        Set<T> s3 = new HashSet<>(s1);
-        s3.addAll(s2);
-        return s3;
-    }
-
-    /** Returns the set difference s1 \ s2. */
-    public static <T> Set<T> difference(Collection<T> s1, Collection<T> s2) {
-        Set<T> s3 = new HashSet<>(s1);
-        s3.removeAll(s2);
-        return s3;
-    }
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTree.java
deleted file mode 100644
index e4f699289..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTree.java
+++ /dev/null
@@ -1,633 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.vptree; - -import lombok.*; -import lombok.extern.slf4j.Slf4j; -import org.deeplearning4j.clustering.sptree.DataPoint; -import org.deeplearning4j.clustering.sptree.HeapObject; -import org.deeplearning4j.clustering.util.MathUtils; -import org.nd4j.linalg.api.memory.MemoryWorkspace; -import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; -import org.nd4j.linalg.api.memory.enums.*; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.reduce3.*; -import org.nd4j.linalg.exception.ND4JIllegalStateException; -import org.nd4j.linalg.factory.Nd4j; - -import java.io.Serializable; -import java.util.*; -import java.util.concurrent.*; -import java.util.concurrent.atomic.AtomicInteger; - -@Slf4j -@Builder -@AllArgsConstructor -public class VPTree implements Serializable { - private static final long serialVersionUID = 1L; - - public static final String EUCLIDEAN = "euclidean"; - private double tau; - @Getter - @Setter - private INDArray items; - private List itemsList; - private Node root; - private String similarityFunction; - @Getter - private boolean invert = false; - private transient ExecutorService executorService; - @Getter - private int workers = 1; - private AtomicInteger size = new AtomicInteger(0); - - private transient ThreadLocal scalars = new ThreadLocal<>(); - - private WorkspaceConfiguration workspaceConfiguration; - - protected VPTree() { - // method for serialization only - scalars = new ThreadLocal<>(); - } - - /** - * - * @param points - * @param invert - */ - public VPTree(INDArray points, boolean invert) { - this(points, "euclidean", 1, invert); - } - - /** - * - * @param points - * @param invert - * @param workers number of parallel workers for tree building (increases memory requirements!) - */ - public VPTree(INDArray points, boolean invert, int workers) { - this(points, "euclidean", workers, invert); - } - - /** - * - * @param items the items to use - * @param similarityFunction the similarity function to use - * @param invert whether to invert the distance (similarity functions have different min/max objectives) - */ - public VPTree(INDArray items, String similarityFunction, boolean invert) { - this.similarityFunction = similarityFunction; - this.invert = invert; - this.items = items; - root = buildFromPoints(items); - workers = 1; - } - - /** - * - * @param items the items to use - * @param similarityFunction the similarity function to use - * @param workers number of parallel workers for tree building (increases memory requirements!) 
- * @param invert whether to invert the metric (different optimization objective) - */ - public VPTree(List items, String similarityFunction, int workers, boolean invert) { - this.workers = workers; - - val list = new INDArray[items.size()]; - - // build list of INDArrays first - for (int i = 0; i < items.size(); i++) - list[i] = items.get(i).getPoint(); - //this.items.putRow(i, items.get(i).getPoint()); - - // just stack them out with concat :) - this.items = Nd4j.pile(list); - - this.invert = invert; - this.similarityFunction = similarityFunction; - root = buildFromPoints(this.items); - } - - - - /** - * - * @param items - * @param similarityFunction - */ - public VPTree(INDArray items, String similarityFunction) { - this(items, similarityFunction, 1, false); - } - - /** - * - * @param items - * @param similarityFunction - * @param workers number of parallel workers for tree building (increases memory requirements!) - * @param invert - */ - public VPTree(INDArray items, String similarityFunction, int workers, boolean invert) { - this.similarityFunction = similarityFunction; - this.invert = invert; - this.items = items; - - this.workers = workers; - root = buildFromPoints(items); - } - - - /** - * - * @param items - * @param similarityFunction - */ - public VPTree(List items, String similarityFunction) { - this(items, similarityFunction, 1, false); - } - - - /** - * - * @param items - */ - public VPTree(INDArray items) { - this(items, EUCLIDEAN); - } - - - /** - * - * @param items - */ - public VPTree(List items) { - this(items, EUCLIDEAN); - } - - /** - * Create an ndarray - * from the datapoints - * @param data - * @return - */ - public static INDArray buildFromData(List data) { - INDArray ret = Nd4j.create(data.size(), data.get(0).getD()); - for (int i = 0; i < ret.slices(); i++) - ret.putSlice(i, data.get(i).getPoint()); - return ret; - } - - - - /** - * - * @param basePoint - * @param distancesArr - */ - public void calcDistancesRelativeTo(INDArray items, INDArray basePoint, INDArray distancesArr) { - switch (similarityFunction) { - case "euclidean": - Nd4j.getExecutioner().exec(new EuclideanDistance(items, basePoint, distancesArr, true,-1)); - break; - case "cosinedistance": - Nd4j.getExecutioner().exec(new CosineDistance(items, basePoint, distancesArr, true, -1)); - break; - case "cosinesimilarity": - Nd4j.getExecutioner().exec(new CosineSimilarity(items, basePoint, distancesArr, true, -1)); - break; - case "manhattan": - Nd4j.getExecutioner().exec(new ManhattanDistance(items, basePoint, distancesArr, true, -1)); - break; - case "dot": - Nd4j.getExecutioner().exec(new Dot(items, basePoint, distancesArr, -1)); - break; - case "jaccard": - Nd4j.getExecutioner().exec(new JaccardDistance(items, basePoint, distancesArr, true, -1)); - break; - case "hamming": - Nd4j.getExecutioner().exec(new HammingDistance(items, basePoint, distancesArr, true, -1)); - break; - default: - Nd4j.getExecutioner().exec(new EuclideanDistance(items, basePoint, distancesArr, true, -1)); - break; - - } - - if (invert) - distancesArr.negi(); - - } - - public void calcDistancesRelativeTo(INDArray basePoint, INDArray distancesArr) { - calcDistancesRelativeTo(items, basePoint, distancesArr); - } - - - /** - * Euclidean distance - * @return the distance between the two points - */ - public double distance(INDArray arr1, INDArray arr2) { - if (scalars == null) - scalars = new ThreadLocal<>(); - - if (scalars.get() == null) - scalars.set(Nd4j.scalar(arr1.dataType(), 0.0)); - - switch (similarityFunction) { - case 
"jaccard": - double ret7 = Nd4j.getExecutioner() - .execAndReturn(new JaccardDistance(arr1, arr2, scalars.get())) - .getFinalResult().doubleValue(); - return invert ? -ret7 : ret7; - case "hamming": - double ret8 = Nd4j.getExecutioner() - .execAndReturn(new HammingDistance(arr1, arr2, scalars.get())) - .getFinalResult().doubleValue(); - return invert ? -ret8 : ret8; - case "euclidean": - double ret = Nd4j.getExecutioner() - .execAndReturn(new EuclideanDistance(arr1, arr2, scalars.get())) - .getFinalResult().doubleValue(); - return invert ? -ret : ret; - case "cosinesimilarity": - double ret2 = Nd4j.getExecutioner() - .execAndReturn(new CosineSimilarity(arr1, arr2, scalars.get())) - .getFinalResult().doubleValue(); - return invert ? -ret2 : ret2; - case "cosinedistance": - double ret6 = Nd4j.getExecutioner() - .execAndReturn(new CosineDistance(arr1, arr2, scalars.get())) - .getFinalResult().doubleValue(); - return invert ? -ret6 : ret6; - case "manhattan": - double ret3 = Nd4j.getExecutioner() - .execAndReturn(new ManhattanDistance(arr1, arr2, scalars.get())) - .getFinalResult().doubleValue(); - return invert ? -ret3 : ret3; - case "dot": - double dotRet = Nd4j.getBlasWrapper().dot(arr1, arr2); - return invert ? -dotRet : dotRet; - default: - double ret4 = Nd4j.getExecutioner() - .execAndReturn(new EuclideanDistance(arr1, arr2, scalars.get())) - .getFinalResult().doubleValue(); - return invert ? -ret4 : ret4; - - } - } - - protected class NodeBuilder implements Callable { - protected List list; - protected List indices; - - public NodeBuilder(List list, List indices) { - this.list = list; - this.indices = indices; - } - - @Override - public Node call() throws Exception { - return buildFromPoints(list, indices); - } - } - - private Node buildFromPoints(List points, List indices) { - Node ret = new Node(0, 0); - - - // nothing to sort here - if (points.size() == 1) { - ret.point = points.get(0); - ret.index = indices.get(0); - return ret; - } - - // opening workspace, and creating it if that's the first call - /* MemoryWorkspace workspace = - Nd4j.getWorkspaceManager().getAndActivateWorkspace(workspaceConfiguration, "VPTREE_WORSKPACE");*/ - - INDArray items = Nd4j.vstack(points); - int randomPoint = MathUtils.randomNumberBetween(0, items.rows() - 1, Nd4j.getRandom()); - INDArray basePoint = points.get(randomPoint);//items.getRow(randomPoint); - ret.point = basePoint; - ret.index = indices.get(randomPoint); - INDArray distancesArr = Nd4j.create(items.rows(), 1); - - calcDistancesRelativeTo(items, basePoint, distancesArr); - - double medianDistance = distancesArr.medianNumber().doubleValue(); - - ret.threshold = (float) medianDistance; - - List leftPoints = new ArrayList<>(); - List leftIndices = new ArrayList<>(); - List rightPoints = new ArrayList<>(); - List rightIndices = new ArrayList<>(); - - for (int i = 0; i < distancesArr.length(); i++) { - if (i == randomPoint) - continue; - - if (distancesArr.getDouble(i) < medianDistance) { - leftPoints.add(points.get(i)); - leftIndices.add(indices.get(i)); - } else { - rightPoints.add(points.get(i)); - rightIndices.add(indices.get(i)); - } - } - - // closing workspace - //workspace.notifyScopeLeft(); - //log.info("Thread: {}; Workspace size: {} MB; ConstantCache: {}; ShapeCache: {}; TADCache: {}", Thread.currentThread().getId(), (int) (workspace.getCurrentSize() / 1024 / 1024 ), Nd4j.getConstantHandler().getCachedBytes(), Nd4j.getShapeInfoProvider().getCachedBytes(), Nd4j.getExecutioner().getTADManager().getCachedBytes()); - - if (workers > 1) { - 
if (!leftPoints.isEmpty()) - ret.futureLeft = executorService.submit(new NodeBuilder(leftPoints, leftIndices)); // = buildFromPoints(leftPoints); - - if (!rightPoints.isEmpty()) - ret.futureRight = executorService.submit(new NodeBuilder(rightPoints, rightIndices)); - } else { - if (!leftPoints.isEmpty()) - ret.left = buildFromPoints(leftPoints, leftIndices); - - if (!rightPoints.isEmpty()) - ret.right = buildFromPoints(rightPoints, rightIndices); - } - - return ret; - } - - private Node buildFromPoints(INDArray items) { - if (executorService == null && items == this.items && workers > 1) { - final val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); - - executorService = Executors.newFixedThreadPool(workers, new ThreadFactory() { - @Override - public Thread newThread(final Runnable r) { - Thread t = new Thread(new Runnable() { - - @Override - public void run() { - Nd4j.getAffinityManager().unsafeSetDevice(deviceId); - r.run(); - } - }); - - t.setDaemon(true); - t.setName("VPTree thread"); - - return t; - } - }); - } - - - final Node ret = new Node(0, 0); - size.incrementAndGet(); - - /*workspaceConfiguration = WorkspaceConfiguration.builder().cyclesBeforeInitialization(1) - .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.FIRST_LOOP) - .policyMirroring(MirroringPolicy.FULL).policyReset(ResetPolicy.BLOCK_LEFT) - .policySpill(SpillPolicy.REALLOCATE).build(); - - // opening workspace - MemoryWorkspace workspace = - Nd4j.getWorkspaceManager().getAndActivateWorkspace(workspaceConfiguration, "VPTREE_WORSKPACE");*/ - - int randomPoint = MathUtils.randomNumberBetween(0, items.rows() - 1, Nd4j.getRandom()); - INDArray basePoint = items.getRow(randomPoint, true); - INDArray distancesArr = Nd4j.create(items.rows(), 1); - ret.point = basePoint; - ret.index = randomPoint; - - calcDistancesRelativeTo(items, basePoint, distancesArr); - - double medianDistance = distancesArr.medianNumber().doubleValue(); - - ret.threshold = (float) medianDistance; - - List leftPoints = new ArrayList<>(); - List leftIndices = new ArrayList<>(); - List rightPoints = new ArrayList<>(); - List rightIndices = new ArrayList<>(); - - for (int i = 0; i < distancesArr.length(); i++) { - if (i == randomPoint) - continue; - - if (distancesArr.getDouble(i) < medianDistance) { - leftPoints.add(items.getRow(i, true)); - leftIndices.add(i); - } else { - rightPoints.add(items.getRow(i, true)); - rightIndices.add(i); - } - } - - // closing workspace - //workspace.notifyScopeLeft(); - //workspace.destroyWorkspace(true); - - if (!leftPoints.isEmpty()) - ret.left = buildFromPoints(leftPoints, leftIndices); - - if (!rightPoints.isEmpty()) - ret.right = buildFromPoints(rightPoints, rightIndices); - - // destroy once again - //workspace.destroyWorkspace(true); - - if (ret.left != null) - ret.left.fetchFutures(); - - if (ret.right != null) - ret.right.fetchFutures(); - - if (executorService != null) - executorService.shutdown(); - - return ret; - } - - public void search(@NonNull INDArray target, int k, List results, List distances) { - search(target, k, results, distances, true); - } - - public void search(@NonNull INDArray target, int k, List results, List distances, - boolean filterEqual) { - search(target, k, results, distances, filterEqual, false); - } - /** - * - * @param target - * @param k - * @param results - * @param distances - */ - public void search(@NonNull INDArray target, int k, List results, List distances, - boolean filterEqual, boolean dropEdge) { - if (items != null) - if 
(!target.isVectorOrScalar() || target.columns() != items.columns() || target.rows() > 1) - throw new ND4JIllegalStateException("Target for search should have shape of [" + 1 + ", " - + items.columns() + "] but got " + Arrays.toString(target.shape()) + " instead"); - - k = Math.min(k, items.rows()); - results.clear(); - distances.clear(); - - PriorityQueue pq = new PriorityQueue<>(items.rows(), new HeapObjectComparator()); - - search(root, target, k + (filterEqual ? 2 : 1), pq, Double.MAX_VALUE); - - while (!pq.isEmpty()) { - HeapObject ho = pq.peek(); - results.add(new DataPoint(ho.getIndex(), ho.getPoint())); - distances.add(ho.getDistance()); - pq.poll(); - } - - Collections.reverse(results); - Collections.reverse(distances); - - if (dropEdge || results.size() > k) { - if (filterEqual && distances.get(0) == 0.0) { - results.remove(0); - distances.remove(0); - } - - while (results.size() > k) { - results.remove(results.size() - 1); - distances.remove(distances.size() - 1); - } - } - } - - /** - * - * @param node - * @param target - * @param k - * @param pq - */ - public void search(Node node, INDArray target, int k, PriorityQueue pq, double cTau) { - - if (node == null) - return; - - double tau = cTau; - - INDArray get = node.getPoint(); //items.getRow(node.getIndex()); - double distance = distance(get, target); - if (distance < tau) { - if (pq.size() == k) - pq.poll(); - - pq.add(new HeapObject(node.getIndex(), node.getPoint(), distance)); - if (pq.size() == k) - tau = pq.peek().getDistance(); - } - - Node left = node.getLeft(); - Node right = node.getRight(); - - if (left == null && right == null) - return; - - if (distance < node.getThreshold()) { - if (distance - tau < node.getThreshold()) { // if there can still be neighbors inside the ball, recursively search left child first - search(left, target, k, pq, tau); - } - - if (distance + tau >= node.getThreshold()) { // if there can still be neighbors outside the ball, recursively search right child - search(right, target, k, pq, tau); - } - - } else { - if (distance + tau >= node.getThreshold()) { // if there can still be neighbors outside the ball, recursively search right child first - search(right, target, k, pq, tau); - } - - if (distance - tau < node.getThreshold()) { // if there can still be neighbors inside the ball, recursively search left child - search(left, target, k, pq, tau); - } - } - - } - - - protected class HeapObjectComparator implements Comparator { - - @Override - public int compare(HeapObject o1, HeapObject o2) { - return Double.compare(o2.getDistance(), o1.getDistance()); - } - } - - @Data - public static class Node implements Serializable { - private static final long serialVersionUID = 2L; - - private int index; - private float threshold; - private Node left, right; - private INDArray point; - protected transient Future futureLeft; - protected transient Future futureRight; - - public Node(int index, float threshold) { - this.index = index; - this.threshold = threshold; - } - - - public void fetchFutures() { - try { - if (futureLeft != null) { - /*while (!futureLeft.isDone()) - Thread.sleep(100);*/ - - - left = futureLeft.get(); - } - - if (futureRight != null) { - /*while (!futureRight.isDone()) - Thread.sleep(100);*/ - - right = futureRight.get(); - } - - - if (left != null) - left.fetchFutures(); - - if (right != null) - right.fetchFutures(); - } catch (Exception e) { - throw new RuntimeException(e); - } - - - } - } - -} diff --git 
a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTreeFillSearch.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTreeFillSearch.java
deleted file mode 100644
index 2cf87d69b..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/VPTreeFillSearch.java
+++ /dev/null
@@ -1,79 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.vptree;
-
-import lombok.Getter;
-import org.deeplearning4j.clustering.sptree.DataPoint;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
-
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * Brute-force "fill" search: ranks every item held by the tree by its
- * distance to the target and keeps the closest k, used to top up
- * VP-tree results until k neighbors are available.
- */
-public class VPTreeFillSearch {
-    private VPTree vpTree;
-    private int k;
-    @Getter
-    private List<DataPoint> results;
-    @Getter
-    private List<Double> distances;
-    private INDArray target;
-
-    public VPTreeFillSearch(VPTree vpTree, int k, INDArray target) {
-        this.vpTree = vpTree;
-        this.k = k;
-        this.target = target;
-    }
-
-    public void search() {
-        results = new ArrayList<>();
-        distances = new ArrayList<>();
-        // Compute the distance from the target to every item, then sort
-        // (ascending, or descending when the similarity metric is inverted)
-        INDArray distancesArr = Nd4j.create(vpTree.getItems().rows(), 1);
-        vpTree.calcDistancesRelativeTo(target, distancesArr);
-        INDArray[] sortWithIndices = Nd4j.sortWithIndices(distancesArr, 0, !vpTree.isInvert());
-        if (vpTree.getItems().isVector()) {
-            for (int i = 0; i < k; i++) {
-                int idx = sortWithIndices[0].getInt(i);
-                results.add(new DataPoint(idx, Nd4j.scalar(vpTree.getItems().getDouble(idx))));
-                // the sorted distances are indexed by rank i, not by item index idx
-                distances.add(sortWithIndices[1].getDouble(i));
-            }
-        } else {
-            for (int i = 0; i < k; i++) {
-                int idx = sortWithIndices[0].getInt(i);
-                results.add(new DataPoint(idx, vpTree.getItems().getRow(idx)));
-                distances.add(sortWithIndices[1].getDouble(i));
-            }
-        }
-    }
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/package-info.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/package-info.java
deleted file mode 100644
index 49d19a719..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/vptree/package-info.java
+++ 
/dev/null @@ -1,21 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.vptree; diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/cluster/ClusterSetTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/cluster/ClusterSetTest.java deleted file mode 100644 index 5a83fa85b..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/cluster/ClusterSetTest.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.cluster; - -import org.junit.Assert; -import org.junit.Test; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.ArrayList; -import java.util.List; - -public class ClusterSetTest { - @Test - public void testGetMostPopulatedClusters() { - ClusterSet clusterSet = new ClusterSet(false); - List clusters = new ArrayList<>(); - for (int i = 0; i < 5; i++) { - Cluster cluster = new Cluster(); - cluster.setPoints(Point.toPoints(Nd4j.randn(i + 1, 5))); - clusters.add(cluster); - } - clusterSet.setClusters(clusters); - List mostPopulatedClusters = clusterSet.getMostPopulatedClusters(5); - for (int i = 0; i < 5; i++) { - Assert.assertEquals(5 - i, mostPopulatedClusters.get(i).getPoints().size()); - } - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java deleted file mode 100644 index e436d62f5..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kdtree/KDTreeTest.java +++ /dev/null @@ -1,422 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.kdtree; - -import lombok.val; -import org.deeplearning4j.BaseDL4JTest; -import org.joda.time.Duration; -import org.junit.Before; -import org.junit.BeforeClass; -import org.junit.Ignore; -import org.junit.Test; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.Pair; -import org.nd4j.shade.guava.base.Stopwatch; -import org.nd4j.shade.guava.primitives.Doubles; -import org.nd4j.shade.guava.primitives.Floats; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Random; - -import static java.util.concurrent.TimeUnit.MILLISECONDS; -import static java.util.concurrent.TimeUnit.SECONDS; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -public class KDTreeTest extends BaseDL4JTest { - - @Override - public long getTimeoutMilliseconds() { - return 120000L; - } - - private KDTree kdTree; - - @BeforeClass - public static void beforeClass(){ - Nd4j.setDataType(DataType.FLOAT); - } - - @Before - public void setUp() { - kdTree = new KDTree(2); - float[] data = new float[]{7,2}; - kdTree.insert(Nd4j.createFromArray(data)); - data = new float[]{5,4}; - kdTree.insert(Nd4j.createFromArray(data)); - data = new float[]{2,3}; - kdTree.insert(Nd4j.createFromArray(data)); - data = new float[]{4,7}; - kdTree.insert(Nd4j.createFromArray(data)); - data = new float[]{9,6}; - kdTree.insert(Nd4j.createFromArray(data)); - data = new float[]{8,1}; - kdTree.insert(Nd4j.createFromArray(data)); - } - - @Test - public void testTree() { - KDTree tree = new KDTree(2); - INDArray half = Nd4j.create(new double[] {0.5, 0.5}, new long[]{1,2}).castTo(DataType.FLOAT); - INDArray one = Nd4j.create(new double[] {1, 1}, new long[]{1,2}).castTo(DataType.FLOAT); - tree.insert(half); - tree.insert(one); - Pair pair = tree.nn(Nd4j.create(new double[] {0.5, 0.5}, new long[]{1,2}).castTo(DataType.FLOAT)); - assertEquals(half, pair.getValue()); - } - - @Test - public void testInsert() { - int elements = 10; - List digits = Arrays.asList(1.0, 0.0, 2.0, 3.0); - - KDTree kdTree = new KDTree(digits.size()); - List> lists = new ArrayList<>(); - for (int i = 0; i < elements; i++) { - List thisList = new ArrayList<>(digits.size()); - for (int k = 0; k < digits.size(); k++) { - thisList.add(digits.get(k) + i); - } - lists.add(thisList); - } - - for (int i = 0; i < elements; i++) { - double[] features = Doubles.toArray(lists.get(i)); - INDArray ind = Nd4j.create(features, new long[]{1, features.length}, DataType.FLOAT); - kdTree.insert(ind); - assertEquals(i + 1, kdTree.size()); - } - } - - @Test - public void testDelete() { - int elements = 10; - List digits = Arrays.asList(1.0, 0.0, 2.0, 3.0); - - KDTree kdTree = new KDTree(digits.size()); - List> lists = new ArrayList<>(); - for (int i = 0; i < elements; i++) { - List thisList = new ArrayList<>(digits.size()); - for (int k = 0; k < digits.size(); k++) { - thisList.add(digits.get(k) + i); - } - lists.add(thisList); - } - - INDArray toDelete = Nd4j.empty(DataType.DOUBLE), - leafToDelete = Nd4j.empty(DataType.DOUBLE); - for (int i = 0; i < elements; i++) { - double[] features = Doubles.toArray(lists.get(i)); - INDArray ind = Nd4j.create(features, new long[]{1, features.length}, DataType.FLOAT); - if (i == 1) - toDelete = ind; - if 
(i == elements - 1) { - leafToDelete = ind; - } - kdTree.insert(ind); - assertEquals(i + 1, kdTree.size()); - } - - kdTree.delete(toDelete); - assertEquals(9, kdTree.size()); - kdTree.delete(leafToDelete); - assertEquals(8, kdTree.size()); - } - - @Test - public void testNN() { - int n = 10; - - // make a KD-tree of dimension {#n} - KDTree kdTree = new KDTree(n); - for (int i = -1; i < n; i++) { - // Insert a unit vector along each dimension - List vec = new ArrayList<>(n); - // i = -1 ensures the origin is in the Tree - for (int k = 0; k < n; k++) { - vec.add((k == i) ? 1.0 : 0.0); - } - INDArray indVec = Nd4j.create(Doubles.toArray(vec), new long[]{1, vec.size()}, DataType.FLOAT); - kdTree.insert(indVec); - } - Random rand = new Random(); - - // random point in the Hypercube - List pt = new ArrayList(n); - for (int k = 0; k < n; k++) { - pt.add(rand.nextDouble()); - } - Pair result = kdTree.nn(Nd4j.create(Doubles.toArray(pt), new long[]{1, pt.size()}, DataType.FLOAT)); - - // Always true for points in the unitary hypercube - assertTrue(result.getKey() < Double.MAX_VALUE); - - } - - @Test - public void testKNN() { - int dimensions = 512; - int vectorsNo = isIntegrationTests() ? 50000 : 1000; - // make a KD-tree of dimension {#dimensions} - Stopwatch stopwatch = Stopwatch.createStarted(); - KDTree kdTree = new KDTree(dimensions); - for (int i = -1; i < vectorsNo; i++) { - // Insert a unit vector along each dimension - INDArray indVec = Nd4j.rand(DataType.FLOAT, 1,dimensions); - kdTree.insert(indVec); - } - stopwatch.stop(); - System.out.println("Time elapsed for " + kdTree.size() + " nodes construction is "+ stopwatch.elapsed(SECONDS)); - - Random rand = new Random(); - // random point in the Hypercube - List pt = new ArrayList(dimensions); - for (int k = 0; k < dimensions; k++) { - pt.add(rand.nextFloat() * 10.0); - } - stopwatch.reset(); - stopwatch.start(); - List> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Floats.toArray(pt))), 20.0f); - stopwatch.stop(); - System.out.println("Time elapsed for Search is "+ stopwatch.elapsed(MILLISECONDS)); - } - - @Test - public void testKNN_Simple() { - int n = 2; - KDTree kdTree = new KDTree(n); - - float[] data = new float[]{3,3}; - kdTree.insert(Nd4j.createFromArray(data)); - data = new float[]{1,1}; - kdTree.insert(Nd4j.createFromArray(data)); - data = new float[]{2,2}; - kdTree.insert(Nd4j.createFromArray(data)); - - data = new float[]{0,0}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 4.5f); - - assertEquals(1.0, result.get(0).getSecond().getDouble(0), 1e-5); - assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5); - - assertEquals(2.0, result.get(1).getSecond().getDouble(0), 1e-5); - assertEquals(2.0, result.get(1).getSecond().getDouble(1), 1e-5); - - assertEquals(3.0, result.get(2).getSecond().getDouble(0), 1e-5); - assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5); - } - - @Test - public void testKNN_1() { - - assertEquals(6, kdTree.size()); - - float[] data = new float[]{8,1}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f); - assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5); - assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5); - assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5); - assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5); - assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5); - assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5); - assertEquals(9.0, result.get(3).getSecond().getFloat(0), 1e-5); - 
assertEquals(6.0, result.get(3).getSecond().getFloat(1), 1e-5); - assertEquals(2.0, result.get(4).getSecond().getFloat(0), 1e-5); - assertEquals(3.0, result.get(4).getSecond().getFloat(1), 1e-5); - assertEquals(4.0, result.get(5).getSecond().getFloat(0), 1e-5); - assertEquals(7.0, result.get(5).getSecond().getFloat(1), 1e-5); - } - - @Test - public void testKNN_2() { - float[] data = new float[]{8, 1}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f); - assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5); - assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5); - assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5); - assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5); - assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5); - assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5); - } - - @Test - public void testKNN_3() { - - float[] data = new float[]{2, 3}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f); - assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5); - assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5); - assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5); - assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5); - assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5); - assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5); - assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5); - assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5); - assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5); - assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5); - assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5); - assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5); - } - - - @Test - public void testKNN_4() { - float[] data = new float[]{2, 3}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f); - assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5); - assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5); - assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5); - assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5); - assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5); - assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5); - } - - @Test - public void testKNN_5() { - float[] data = new float[]{2, 3}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f); - assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5); - assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5); - assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5); - assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5); - assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5); - assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5); - assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5); - assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5); - assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5); - assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5); - assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5); - assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5); - } - - @Test - public void test_KNN_6() { - float[] data = new float[]{4, 6}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f); - assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); - assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); - 
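// Editor's note (added commentary, not original source): kdTree.knn returns
// neighbors sorted by ascending Euclidean distance from the query, so the
// assertions in these tests pin down both the coordinates and the rank
// order of each neighbor of the query point {4, 6}.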
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); - assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); - assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5); - assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5); - assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5); - assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5); - assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5); - assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5); - assertEquals(8.0, result.get(5).getSecond().getDouble(0), 1e-5); - assertEquals(1.0, result.get(5).getSecond().getDouble(1), 1e-5); - } - - @Test - public void test_KNN_7() { - float[] data = new float[]{4, 6}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f); - assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); - assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); - assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); - assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); - assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5); - assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5); - assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5); - assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5); - assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5); - assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5); - } - - @Test - public void test_KNN_8() { - float[] data = new float[]{4, 6}; - List> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f); - assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); - assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); - assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); - assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); - assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5); - assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5); - assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5); - assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5); - assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5); - assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5); - assertEquals(8.0, result.get(5).getSecond().getDouble(0), 1e-5); - assertEquals(1.0, result.get(5).getSecond().getDouble(1), 1e-5); - } - - @Test - public void testNoDuplicates() { - int N = 100; - KDTree bigTree = new KDTree(2); - - List points = new ArrayList<>(); - for (int i = 0; i < N; ++i) { - double[] data = new double[]{i, i}; - points.add(Nd4j.createFromArray(data)); - } - - for (int i = 0; i < N; ++i) { - bigTree.insert(points.get(i)); - } - - assertEquals(N, bigTree.size()); - - INDArray node = Nd4j.empty(DataType.DOUBLE); - for (int i = 0; i < N; ++i) { - node = bigTree.delete(node.isEmpty() ? 
points.get(i) : node); - } - - assertEquals(0, bigTree.size()); - } - - @Ignore - @Test - public void performanceTest() { - int n = 2; - int num = 100000; - // make a KD-tree of dimension {#n} - long start = System.currentTimeMillis(); - KDTree kdTree = new KDTree(n); - INDArray inputArrray = Nd4j.randn(DataType.DOUBLE, num, n); - for (int i = 0 ; i < num; ++i) { - kdTree.insert(inputArrray.getRow(i)); - } - - long end = System.currentTimeMillis(); - Duration duration = new Duration(start, end); - System.out.println("Elapsed time for tree construction " + duration.getStandardSeconds() + " " + duration.getMillis()); - - List pt = new ArrayList(num); - for (int k = 0; k < n; k++) { - pt.add((float)(num / 2)); - } - start = System.currentTimeMillis(); - List> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0f); - end = System.currentTimeMillis(); - duration = new Duration(start, end); - long elapsed = end - start; - System.out.println("Elapsed time for tree search " + duration.getStandardSeconds() + " " + duration.getMillis()); - for (val pair : list) { - System.out.println(pair.getFirst() + " " + pair.getSecond()) ; - } - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java deleted file mode 100644 index e3a2467ec..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/kmeans/KMeansTest.java +++ /dev/null @@ -1,289 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.kmeans; - -import lombok.val; -import org.apache.commons.lang3.time.StopWatch; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.clustering.algorithm.Distance; -import org.deeplearning4j.clustering.cluster.*; -import org.junit.Ignore; -import org.junit.Test; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.List; - -import static org.junit.Assert.*; - -public class KMeansTest extends BaseDL4JTest { - - private boolean[] useKMeansPlusPlus = {true, false}; - - @Override - public long getTimeoutMilliseconds() { - return 60000L; - } - - @Test - public void testKMeans() { - Nd4j.getRandom().setSeed(7); - for (boolean mode : useKMeansPlusPlus) { - KMeansClustering kMeansClustering = KMeansClustering.setup(5, 5, Distance.EUCLIDEAN, mode); - List points = Point.toPoints(Nd4j.randn(5, 5)); - ClusterSet clusterSet = kMeansClustering.applyTo(points); - PointClassification pointClassification = clusterSet.classifyPoint(points.get(0)); - System.out.println(pointClassification); - } - } - - @Test - public void testKmeansCosine() { - - Nd4j.getRandom().setSeed(7); - int numClusters = 5; - for (boolean mode : useKMeansPlusPlus) { - KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.COSINE_DISTANCE, mode); - List points = Point.toPoints(Nd4j.rand(5, 300)); - ClusterSet clusterSet = kMeansClustering.applyTo(points); - PointClassification pointClassification = clusterSet.classifyPoint(points.get(0)); - - - KMeansClustering kMeansClusteringEuclidean = KMeansClustering.setup(numClusters, 1000, Distance.EUCLIDEAN, mode); - ClusterSet clusterSetEuclidean = kMeansClusteringEuclidean.applyTo(points); - PointClassification pointClassificationEuclidean = clusterSetEuclidean.classifyPoint(points.get(0)); - System.out.println("Cosine " + pointClassification); - System.out.println("Euclidean " + pointClassificationEuclidean); - - assertEquals(pointClassification.getCluster().getPoints().get(0), - pointClassificationEuclidean.getCluster().getPoints().get(0)); - } - } - - @Ignore - @Test - public void testPerformanceAllIterations() { - Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); - Nd4j.getRandom().setSeed(7); - int numClusters = 20; - for (boolean mode : useKMeansPlusPlus) { - StopWatch watch = new StopWatch(); - watch.start(); - KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.COSINE_DISTANCE, mode); - List points = Point.toPoints(Nd4j.linspace(0, 5000 * 300, 5000 * 300).reshape(5000, 300)); - - ClusterSet clusterSet = kMeansClustering.applyTo(points); - watch.stop(); - System.out.println("Elapsed for clustering : " + watch); - - watch.reset(); - watch.start(); - for (Point p : points) { - PointClassification pointClassification = clusterSet.classifyPoint(p); - } - watch.stop(); - System.out.println("Elapsed for search: " + watch); - } - } - - @Test - @Ignore - public void testPerformanceWithConvergence() { - Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); - Nd4j.getRandom().setSeed(7); - int numClusters = 20; - for (boolean mode : useKMeansPlusPlus) { - StopWatch watch = new StopWatch(); - watch.start(); - KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, Distance.COSINE_DISTANCE, false, mode); - - List points 
= Point.toPoints(Nd4j.linspace(0, 10000 * 300, 10000 * 300).reshape(10000, 300)); - - ClusterSet clusterSet = kMeansClustering.applyTo(points); - watch.stop(); - System.out.println("Elapsed for clustering : " + watch); - - watch.reset(); - watch.start(); - for (Point p : points) { - PointClassification pointClassification = clusterSet.classifyPoint(p); - } - watch.stop(); - System.out.println("Elapsed for search: " + watch); - - watch.reset(); - watch.start(); - kMeansClustering = KMeansClustering.setup(numClusters, 0.05, Distance.COSINE_DISTANCE, false, mode); - - points = Point.toPoints(Nd4j.linspace(0, 10000 * 300, 10000 * 300).reshape(10000, 300)); - - clusterSet = kMeansClustering.applyTo(points); - watch.stop(); - System.out.println("Elapsed for clustering : " + watch); - - watch.reset(); - watch.start(); - for (Point p : points) { - PointClassification pointClassification = clusterSet.classifyPoint(p); - } - watch.stop(); - System.out.println("Elapsed for search: " + watch); - } - } - - @Test - public void testCorrectness() { - - /*for (int c = 0; c < 10; ++c)*/ { - Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); - Nd4j.getRandom().setSeed(7); - int numClusters = 3; - for (boolean mode : useKMeansPlusPlus) { - KMeansClustering kMeansClustering = KMeansClustering.setup(numClusters, 1000, Distance.EUCLIDEAN, mode); - double[] data = new double[]{ - 15, 16, - 16, 18.5, - 17, 20.2, - 16.4, 17.12, - 17.23, 18.12, - 43, 43, - 44.43, 45.212, - 45.8, 54.23, - 46.313, 43.123, - 50.21, 46.3, - 99, 99.22, - 100.32, 98.123, - 100.32, 97.423, - 102, 93.23, - 102.23, 94.23 - }; - List points = Point.toPoints(Nd4j.createFromArray(data).reshape(15, 2)); - - ClusterSet clusterSet = kMeansClustering.applyTo(points); - - - INDArray row0 = Nd4j.createFromArray(new double[]{16.6575, 18.4850}); - INDArray row1 = Nd4j.createFromArray(new double[]{32.6050, 31.1500}); - INDArray row2 = Nd4j.createFromArray(new double[]{75.9348, 74.1990}); - - /*List clusters = clusterSet.getClusters(); - assertEquals(row0, clusters.get(0).getCenter().getArray()); - assertEquals(row1, clusters.get(1).getCenter().getArray()); - assertEquals(row2, clusters.get(2).getCenter().getArray());*/ - - PointClassification pointClassification = null; - for (Point p : points) { - pointClassification = clusterSet.classifyPoint(p); - System.out.println("Point: " + p.getArray() + " " + " assigned to cluster: " + pointClassification.getCluster().getCenter().getArray()); - List clusters = clusterSet.getClusters(); - for (int i = 0; i < clusters.size(); ++i) - System.out.println("Choice: " + clusters.get(i).getCenter().getArray()); - } - } - /*assertEquals(Nd4j.createFromArray(new double[]{75.9348, 74.1990}), - pointClassification.getCluster().getCenter().getArray());*/ - - /*clusters = clusterSet.getClusters(); - assertEquals(row0, clusters.get(0).getCenter().getArray()); - assertEquals(row1, clusters.get(1).getCenter().getArray()); - assertEquals(row2, clusters.get(2).getCenter().getArray());*/ - } - } - - @Test - public void testCentersHolder() { - int rows = 3, cols = 2; - CentersHolder ch = new CentersHolder(rows, cols); - - INDArray row0 = Nd4j.createFromArray(new double[]{16.4000, 17.1200}); - INDArray row1 = Nd4j.createFromArray(new double[]{45.8000, 54.2300}); - INDArray row2 = Nd4j.createFromArray(new double[]{95.9348, 94.1990}); - - ch.addCenter(row0); - ch.addCenter(row1); - ch.addCenter(row2); - - double[] data = new double[]{ - 15, 16, - 16, 18.5, - 17, 20.2, - 16.4, 17.12, - 17.23, 18.12, - 43, 43, - 44.43, 45.212, - 
45.8, 54.23, - 46.313, 43.123, - 50.21, 46.3, - 99, 99.22, - 100.32, 98.123, - 100.32, 97.423, - 102, 93.23, - 102.23, 94.23 - }; - - INDArray pointData = Nd4j.createFromArray(data); - List points = Point.toPoints(pointData.reshape(15,2)); - - for (int i = 0 ; i < points.size(); ++i) { - INDArray dist = ch.getMinDistances(points.get(i), Distance.EUCLIDEAN); - System.out.println("Point: " + points.get(i).getArray()); - System.out.println("Centers: " + ch.getCenters()); - System.out.println("Distance: " + dist); - System.out.println(); - } - } - - @Test - public void testInitClusters() { - Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); - Nd4j.getRandom().setSeed(7); - { - KMeansClustering kMeansClustering = KMeansClustering.setup(5, 1, Distance.EUCLIDEAN, true); - - double[][] dataArray = {{1000000.0, 2.8E7, 5.5E7, 8.2E7}, {2.8E7, 5.5E7, 8.2E7, 1.09E8}, {5.5E7, 8.2E7, 1.09E8, 1.36E8}, - {8.2E7, 1.09E8, 1.36E8, 1.63E8}, {1.09E8, 1.36E8, 1.63E8, 1.9E8}, {1.36E8, 1.63E8, 1.9E8, 2.17E8}, - {1.63E8, 1.9E8, 2.17E8, 2.44E8}, {1.9E8, 2.17E8, 2.44E8, 2.71E8}, {2.17E8, 2.44E8, 2.71E8, 2.98E8}, - {2.44E8, 2.71E8, 2.98E8, 3.25E8}, {2.71E8, 2.98E8, 3.25E8, 3.52E8}, {2.98E8, 3.25E8, 3.52E8, 3.79E8}, - {3.25E8, 3.52E8, 3.79E8, 4.06E8}, {3.52E8, 3.79E8, 4.06E8, 4.33E8}, {3.79E8, 4.06E8, 4.33E8, 4.6E8}, - {4.06E8, 4.33E8, 4.6E8, 4.87E8}, {4.33E8, 4.6E8, 4.87E8, 5.14E8}, {4.6E8, 4.87E8, 5.14E8, 5.41E8}, - {4.87E8, 5.14E8, 5.41E8, 5.68E8}, {5.14E8, 5.41E8, 5.68E8, 5.95E8}, {5.41E8, 5.68E8, 5.95E8, 6.22E8}, - {5.68E8, 5.95E8, 6.22E8, 6.49E8}, {5.95E8, 6.22E8, 6.49E8, 6.76E8}, {6.22E8, 6.49E8, 6.76E8, 7.03E8}, - {6.49E8, 6.76E8, 7.03E8, 7.3E8}, {6.76E8, 7.03E8, 7.3E8, 7.57E8}, {7.03E8, 7.3E8, 7.57E8, 7.84E8}}; - INDArray data = Nd4j.createFromArray(dataArray); - List points = Point.toPoints(data); - - ClusterSet clusterSet = kMeansClustering.applyTo(points); - - double[] centroid1 = {2.44e8, 2.71e8, 2.98e8, 3.25e8}; - double[] centroid2 = {1000000.0, 2.8E7, 5.5E7, 8.2E7}; - double[] centroid3 = {5.95E8, 6.22e8, 6.49e8, 6.76e8}; - double[] centroid4 = {3.79E8, 4.06E8, 4.33E8, 4.6E8}; - double[] centroid5 = {5.5E7, 8.2E7, 1.09E8, 1.36E8}; - - assertArrayEquals(centroid1, clusterSet.getClusters().get(0).getCenter().getArray().toDoubleVector(), 1e-4); - assertArrayEquals(centroid2, clusterSet.getClusters().get(1).getCenter().getArray().toDoubleVector(), 1e-4); - assertArrayEquals(centroid3, clusterSet.getClusters().get(2).getCenter().getArray().toDoubleVector(), 1e-4); - assertArrayEquals(centroid4, clusterSet.getClusters().get(3).getCenter().getArray().toDoubleVector(), 1e-4); - assertArrayEquals(centroid5, clusterSet.getClusters().get(4).getCenter().getArray().toDoubleVector(), 1e-4); - } - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSHTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSHTest.java deleted file mode 100644 index 105dd368a..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSHTest.java +++ /dev/null @@ -1,215 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * 
https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.lsh; - -import org.deeplearning4j.BaseDL4JTest; -import org.junit.After; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.Random; - -import static org.junit.Assert.assertArrayEquals; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -public class RandomProjectionLSHTest extends BaseDL4JTest { - - int hashLength = 31; - int numTables = 2; - int intDimensions = 13; - - RandomProjectionLSH rpLSH; - INDArray e1; - INDArray inputs; - - @Before - public void setUp() { - Nd4j.getRandom().setSeed(12345); - Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); - rpLSH = new RandomProjectionLSH(hashLength, numTables, intDimensions, 0.1f); - inputs = Nd4j.rand(DataType.DOUBLE, 100, intDimensions); - e1 = Nd4j.ones(DataType.DOUBLE, 1, intDimensions); - } - - - @After - public void tearDown() { inputs = null; } - - @Test - public void testEntropyDims(){ - assertArrayEquals(new long[]{numTables, intDimensions}, rpLSH.entropy(e1).shape()); - } - - @Test - public void testHashDims(){ - assertArrayEquals(new long[]{1, hashLength}, rpLSH.hash(e1).shape()); - } - - @Test - public void testHashDimsMultiple(){ - INDArray data = Nd4j.ones(1, intDimensions); - assertArrayEquals(new long[]{1, hashLength}, rpLSH.hash(data).shape()); - - data = Nd4j.ones(100, intDimensions); - assertArrayEquals(new long[]{100, hashLength}, rpLSH.hash(data).shape()); - } - - @Test - public void testSigNums(){ - assertEquals(1.0f, rpLSH.hash(e1).aminNumber().floatValue(),1e-3f); - } - - - @Test - public void testIndexDims(){ - rpLSH.makeIndex(Nd4j.rand(100, intDimensions)); - assertArrayEquals(new long[]{100, hashLength}, rpLSH.index.shape()); - } - - - @Test - public void testGetRawBucketOfDims(){ - rpLSH.makeIndex(inputs); - assertArrayEquals(new long[]{100}, rpLSH.rawBucketOf(e1).shape()); - } - - @Test - public void testRawBucketOfReflexive(){ - rpLSH.makeIndex(inputs); - int idx = (new Random(12345)).nextInt(100); - INDArray row = inputs.getRow(idx, true); - assertEquals(1.0f, rpLSH.rawBucketOf(row).maxNumber().floatValue(), 1e-3f); - } - - @Test - public void testBucketDims(){ - rpLSH.makeIndex(inputs); - assertArrayEquals(new long[]{100}, rpLSH.bucket(e1).shape()); - } - - @Test - public void testBucketReflexive(){ - rpLSH.makeIndex(inputs); - int idx = (new Random(12345)).nextInt(100); - INDArray row = inputs.getRow(idx, true); - assertEquals(1.0f, rpLSH.bucket(row).maxNumber().floatValue(), 1e-3f); - } - - - @Test - public void testBucketDataReflexiveDimensions() { - rpLSH.makeIndex(inputs); - int idx = (new Random(12345)).nextInt(100); - INDArray row = 
inputs.getRow(idx, true); - INDArray bucketData = rpLSH.bucketData(row); - - assertEquals(intDimensions, bucketData.shape()[1]); - assertTrue(1 <= bucketData.shape()[0]); - } - - @Test - public void testBucketDataReflexive(){ - rpLSH.makeIndex(inputs); - int idx = (new Random(12345)).nextInt(100); - INDArray row = inputs.getRow(idx, true); - INDArray bucketData = rpLSH.bucketData(row); - - INDArray res = Nd4j.zeros(DataType.BOOL, bucketData.shape()); - Nd4j.getExecutioner().exec(new BroadcastEqualTo(bucketData, row, res, -1)); - res = res.castTo(DataType.FLOAT); - - assertEquals( - String.format("Expected one bucket content to be the query %s, but found %s", row, rpLSH.bucket(row)), - 1.0f, res.min(-1).maxNumber().floatValue(), 1e-3f); - } - - - @Test - public void testSearchReflexiveDimensions() { - rpLSH.makeIndex(inputs); - int idx = (new Random(12345)).nextInt(100); - INDArray row = inputs.getRow(idx, true); - INDArray searchResults = rpLSH.search(row, 10.0f); - - assertTrue( - String.format("Expected the search to return at least one result, the query %s but found %s yielding %d results", row, searchResults, searchResults.shape()[0]), - searchResults.shape()[0] >= 1); - } - - - @Test - public void testSearchReflexive() { - rpLSH.makeIndex(inputs); - int idx = (new Random(12345)).nextInt(100); - INDArray row = inputs.getRow(idx, true); - - INDArray searchResults = rpLSH.search(row, 10.0f); - - - INDArray res = Nd4j.zeros(DataType.BOOL, searchResults.shape()); - Nd4j.getExecutioner().exec(new BroadcastEqualTo(searchResults, row, res, -1)); - res = res.castTo(DataType.FLOAT); - - assertEquals( - String.format("Expected one search result to be the query %s, but found %s", row, searchResults), - 1.0f, res.min(-1).maxNumber().floatValue(), 1e-3f); - } - - - - @Test - public void testANNSearchReflexiveDimensions() { - rpLSH.makeIndex(inputs); - int idx = (new Random(12345)).nextInt(100); - INDArray row = inputs.getRow(idx, true); - INDArray searchResults = rpLSH.search(row, 100); - - assertTrue( - String.format("Expected the search to return at least one result, the query %s but found %s yielding %d results", row, searchResults, searchResults.shape()[0]), - searchResults.shape()[0] >= 1); - } - - - @Test - public void testANNSearchReflexive() { - rpLSH.makeIndex(inputs); - int idx = (new Random(12345)).nextInt(100); - INDArray row = inputs.getRow(idx).reshape(1, intDimensions); - - INDArray searchResults = rpLSH.search(row, 100); - - - INDArray res = Nd4j.zeros(DataType.BOOL, searchResults.shape()); - Nd4j.getExecutioner().exec(new BroadcastEqualTo(searchResults, row, res, -1)); - res = res.castTo(DataType.FLOAT); - - assertEquals( - String.format("Expected one search result to be the query %s, but found %s", row, searchResults), - 1.0f, res.min(-1).maxNumber().floatValue(), 1e-3f); - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/quadtree/QuadTreeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/quadtree/QuadTreeTest.java deleted file mode 100644 index 0cb77bd1d..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/quadtree/QuadTreeTest.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available 
under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.quadtree; - -import org.deeplearning4j.BaseDL4JTest; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -public class QuadTreeTest extends BaseDL4JTest { - - @Test - public void testQuadTree() { - INDArray n = Nd4j.ones(3, 2); - n.slice(1).addi(1); - n.slice(2).addi(2); - QuadTree quadTree = new QuadTree(n); - assertEquals(n.rows(), quadTree.getCumSize()); - assertTrue(quadTree.isCorrect()); - - - - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/randomprojection/RPTreeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/randomprojection/RPTreeTest.java deleted file mode 100644 index abb55a7fd..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/randomprojection/RPTreeTest.java +++ /dev/null @@ -1,101 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.randomprojection; - -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator; -import org.junit.Before; -import org.junit.Test; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.dataset.api.DataSet; -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; -import org.nd4j.linalg.dataset.api.preprocessor.NormalizerMinMaxScaler; -import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize; -import org.nd4j.linalg.factory.Nd4j; - -import java.util.List; - -import static org.junit.Assert.*; - -public class RPTreeTest extends BaseDL4JTest { - - @Before - public void setUp() { - Nd4j.setDefaultDataTypes(DataType.FLOAT, DataType.FLOAT); - } - - - @Test - public void testRPTree() throws Exception { - DataSetIterator mnist = new MnistDataSetIterator(150,150); - RPTree rpTree = new RPTree(784,50); - DataSet d = mnist.next(); - NormalizerStandardize normalizerStandardize = new NormalizerStandardize(); - normalizerStandardize.fit(d); - normalizerStandardize.transform(d.getFeatures()); - INDArray data = d.getFeatures(); - rpTree.buildTree(data); - assertEquals(4,rpTree.getLeaves().size()); - assertEquals(0,rpTree.getRoot().getDepth()); - - List<Integer> candidates = rpTree.getCandidates(data.getRow(0)); - assertFalse(candidates.isEmpty()); - assertEquals(10,rpTree.query(data.slice(0),10).length()); - System.out.println(candidates.size()); - - rpTree.addNodeAtIndex(150,data.getRow(0)); - - } - - @Test - public void testFindSelf() throws Exception { - DataSetIterator mnist = new MnistDataSetIterator(100, 6000); - NormalizerMinMaxScaler minMaxNormalizer = new NormalizerMinMaxScaler(0, 1); - minMaxNormalizer.fit(mnist); - DataSet d = mnist.next(); - minMaxNormalizer.transform(d.getFeatures()); - RPForest rpForest = new RPForest(100, 100, "euclidean"); - rpForest.fit(d.getFeatures()); - for (int i = 0; i < 10; i++) { - INDArray indexes = rpForest.queryAll(d.getFeatures().slice(i), 10); - assertEquals(i,indexes.getInt(0)); - } - } - - @Test - public void testRpTreeMaxNodes() throws Exception { - DataSetIterator mnist = new MnistDataSetIterator(150,150); - RPForest rpTree = new RPForest(4,4,"euclidean"); - DataSet d = mnist.next(); - NormalizerStandardize normalizerStandardize = new NormalizerStandardize(); - normalizerStandardize.fit(d); - rpTree.fit(d.getFeatures()); - for(RPTree tree : rpTree.getTrees()) { - for(RPNode node : tree.getLeaves()) { - assertTrue(node.getIndices().size() <= rpTree.getMaxSize()); - } - } - - } - - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/randomprojection/RPUtilsTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/randomprojection/RPUtilsTest.java deleted file mode 100644 index 18ca2ac9d..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/randomprojection/RPUtilsTest.java +++ /dev/null @@ -1,45 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is
available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.randomprojection; - -import org.deeplearning4j.BaseDL4JTest; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import static org.junit.Assert.assertEquals; - -public class RPUtilsTest extends BaseDL4JTest { - - @Test - public void testDistanceComputeBatch() { - INDArray x = Nd4j.linspace(1,4,4, Nd4j.dataType()).reshape(1, 4); - INDArray y = Nd4j.linspace(1,16,16, Nd4j.dataType()).reshape(4,4); - INDArray result = Nd4j.create(1, 4); - INDArray distances = RPUtils.computeDistanceMulti("euclidean",x,y,result); - INDArray scalarResult = Nd4j.scalar(1.0); - for(int i = 0; i < result.length(); i++) { - double dist = RPUtils.computeDistance("euclidean",x,y.slice(i),scalarResult); - assertEquals(dist,distances.getDouble(i),1e-3); - } - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java deleted file mode 100644 index 0ac39083b..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/sptree/SPTreeTest.java +++ /dev/null @@ -1,108 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. 
- * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.sptree; - -import org.apache.commons.lang3.time.StopWatch; -import org.deeplearning4j.BaseDL4JTest; -import org.junit.Before; -import org.junit.Test; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.buffer.util.DataTypeUtil; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.shade.guava.util.concurrent.AtomicDouble; - -import static org.junit.Assert.*; - -/** - * @author Adam Gibson - */ -public class SPTreeTest extends BaseDL4JTest { - - @Override - public long getTimeoutMilliseconds() { - return 120000L; - } - - @Before - public void setUp() { - DataTypeUtil.setDTypeForContext(DataType.DOUBLE); - } - - @Test - public void testStructure() { - INDArray data = Nd4j.create(new double[][] {{1, 2, 3}, {4, 5, 6}}); - SpTree tree = new SpTree(data); - /*try (MemoryWorkspace ws = tree.workspace().notifyScopeEntered())*/ { - assertEquals(Nd4j.create(new double[]{2.5f, 3.5f, 4.5f}), tree.getCenterOfMass()); - assertEquals(2, tree.getCumSize()); - assertEquals(8, tree.getNumChildren()); - assertTrue(tree.isCorrect()); - } - } - - @Test - public void testComputeEdgeForces() { - Nd4j.setDefaultDataTypes(DataType.DOUBLE, DataType.DOUBLE); - double[] aData = new double[]{ - 0.2999816948164936, 0.26252049735806526, 0.2673853427498767, 0.8604464129156685, 0.4802652829902563, 0.10959096539488711, 0.7950242948008909, 0.5917848948003486, - 0.2738285999345498, 0.9519684328285567, 0.9690024759209738, 0.8585615547624705, 0.8087760944312002, 0.5337951589543348, 0.5960876109129123, 0.7187130179825856, - 0.4629777327445964, 0.08665909175584818, 0.7748005397731237, 0.48020186965468536, 0.24927351841378798, 0.32272599988270445, 0.306414968984427, 0.6980212149215657, - 0.7977183964212472, 0.7673513094629704, 0.1679681724796478, 0.3107359484804584, 0.021701726051792103, 0.13797462786662518, 0.8618953518813538, 0.841333838365635, - 0.5284957375170422, 0.9703367685039823, 0.677388096913733, 0.2624474979832243, 0.43740966353106536, 0.15685545957858893, 0.11072929134449871, 0.06007395961283357, - 0.4093918718557811, 0.9563909195720572, 0.5994144944480242, 0.8278927844215804, 0.38586830957105667, 0.6201844716257464, 0.7603829079070265, 0.07875691596842949, - 0.08651136699915507, 0.7445210640026082, 0.6547649514127559, 0.3384719042666908, 0.05816723105860,0.6248951423054205, 0.7431868493349041}; - INDArray data = Nd4j.createFromArray(aData).reshape(11,5); - INDArray rows = Nd4j.createFromArray(new int[]{ - 0, 9, 18, 27, 36, 45, 54, 63, 72, 81, 90, 99}); - INDArray cols = Nd4j.createFromArray(new int[]{ - 4, 3, 10, 8, 6, 7, 1, 5, 9, 4, 9, 8, 10, 2, 0, 6, 7, 3, 6, 8, 3, 9, 10, 1, 4, 0, 5, 10, 0, 4, 6, 8, 9, 2, 5, 7, 0, 10, 3, 1, 8, 9, 6, 7, 2, 7, 9, 3, 10, 0, 4, 2, 8, 1, 2, 8, 3, 10, 0, 4, 9, 1, 5, 5, 9, 0, 3, 10, 4, 8, 1, 2, 6, 2, 0, 3, 4, 1, 10, 9, 7, 10, 1, 3, 7, 4, 5, 2, 8, 6, 3, 4, 0, 9, 6, 5, 8, 7, 1}); - INDArray vals = Nd4j.createFromArray(new double[] - { 0.6806, 0.1978, 0.1349, 0.0403, 0.0087, 0.0369, 0.0081, 0.0172, 0.0014, 0.0046, 0.0081, 0.3375, 0.2274, 0.0556, 0.0098, 0.0175, 0.0027, 0.0077, 0.0014, 0.0023, 0.0175, 0.6569, 0.1762, 0.0254, 0.0200, 0.0118, 0.0074, 0.0046, 0.0124, 0.0012, 0.1978, 0.0014, 0.0254, 0.7198, 0.0712, 0.0850, 0.0389, 0.0555, 0.0418, 0.0286, 0.6806, 0.3375, 0.0074, 0.0712, 0.2290, 0.0224, 0.0189, 0.0080, 0.0187, 0.0097, 0.0172, 0.0124, 0.0418, 
0.7799, 0.0521, 0.0395, 0.0097, 0.0030, 0.0023, 1.706e-5, 0.0087, 0.0027, 0.6569, 0.0850, 0.0080, 0.5562, 0.0173, 0.0015, 1.706e-5, 0.0369, 0.0077, 0.0286, 0.0187, 0.7799, 0.0711, 0.0200, 0.0084, 0.0012, 0.0403, 0.0556, 0.1762, 0.0389, 0.0224, 0.0030, 0.5562, 0.0084, 0.0060, 0.0028, 0.0014, 0.2274, 0.0200, 0.0555, 0.0189, 0.0521, 0.0015, 0.0711, 0.0028, 0.3911, 0.1349, 0.0098, 0.0118, 0.7198, 0.2290, 0.0395, 0.0173, 0.0200, 0.0060, 0.3911}); - SpTree tree = new SpTree(data); - INDArray posF = Nd4j.create(11, 5); - /*try (MemoryWorkspace ws = tree.workspace().notifyScopeEntered())*/ { - tree.computeEdgeForces(rows, cols, vals, 11, posF); - } - INDArray expected = Nd4j.createFromArray(new double[]{ -0.08045664291717945, -0.1010737980370276, 0.01793326162563703, 0.16108447776416351, -0.20679423033936287, -0.15788549368713395, 0.02546624825966788, 0.062309466206907055, -0.165806093080134, 0.15266225270841186, 0.17508365896345726, 0.09588570563583201, 0.34124767300538084, 0.14606666020839956, -0.06786563815470595, -0.09326646571247202, -0.19896040730569928, -0.3618837364446506, 0.13946315445146712, -0.04570186310149667, -0.2473462951783839, -0.41362278505023914, -0.1094083777758208, 0.10705807646770374, 0.24462088260113946, 0.21722270026621748, -0.21799892431326567, -0.08205544003080587, -0.11170161709042685, -0.2674768703060442, 0.03617747284043274, 0.16430316252598698, 0.04552845070022399, 0.2593696744801452, 0.1439989190892037, -0.059339471967457376, 0.05460893792863096, -0.0595168036583193, -0.2527693197519917, -0.15850951859835274, -0.2945536856938165, 0.15434659331638875, -0.022910846947667776, 0.23598009757792854, -0.11149279745674007, 0.09670616593772939, 0.11125703954547914, -0.08519984596392606, -0.12779827002328714, 0.23025192887225998, 0.13741473964038722, -0.06193553503816597, -0.08349781586292176, 0.1622156410642145, 0.155975447743472}).reshape(11,5); - for (int i = 0; i < 11; ++i) - assertArrayEquals(expected.getRow(i).toDoubleVector(), posF.getRow(i).toDoubleVector(), 1e-2); - - AtomicDouble sumQ = new AtomicDouble(0.0); - /*try (MemoryWorkspace ws = tree.workspace().notifyScopeEntered())*/ { - tree.computeNonEdgeForces(0, 0.5, Nd4j.zeros(5), sumQ); - } - assertEquals(8.65, sumQ.get(), 1e-2); - } - - @Test - //@Ignore - public void testLargeTree() { - int num = isIntegrationTests() ? 100000 : 1000; - StopWatch watch = new StopWatch(); - watch.start(); - INDArray arr = Nd4j.linspace(1, num, num, Nd4j.dataType()).reshape(num, 1); - SpTree tree = new SpTree(arr); - watch.stop(); - System.out.println("Tree of size " + num + " created in " + watch); - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VPTreeSerializationTests.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VPTreeSerializationTests.java deleted file mode 100644 index 86d34b603..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VPTreeSerializationTests.java +++ /dev/null @@ -1,116 +0,0 @@ -/* - * ****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. 
- * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.vptree; - -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.apache.commons.lang3.SerializationUtils; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.clustering.sptree.DataPoint; -import org.junit.Ignore; -import org.junit.Test; -import org.nd4j.linalg.factory.Nd4j; - -import java.io.ByteArrayInputStream; -import java.io.ByteArrayOutputStream; -import java.util.ArrayList; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; - -@Slf4j -public class VPTreeSerializationTests extends BaseDL4JTest { - - @Test - public void testSerialization_1() throws Exception { - val points = Nd4j.rand(new int[] {10, 15}); - val treeA = new VPTree(points, true, 2); - - try (val bos = new ByteArrayOutputStream()) { - SerializationUtils.serialize(treeA, bos); - - try (val bis = new ByteArrayInputStream(bos.toByteArray())) { - VPTree treeB = SerializationUtils.deserialize(bis); - - assertEquals(points, treeA.getItems()); - assertEquals(points, treeB.getItems()); - - assertEquals(treeA.getWorkers(), treeB.getWorkers()); - - val row = points.getRow(1).dup('c'); - - val dpListA = new ArrayList<DataPoint>(); - val dListA = new ArrayList<Double>(); - - val dpListB = new ArrayList<DataPoint>(); - val dListB = new ArrayList<Double>(); - - treeA.search(row, 3, dpListA, dListA); - treeB.search(row, 3, dpListB, dListB); - - assertTrue(dpListA.size() != 0); - assertTrue(dListA.size() != 0); - - assertEquals(dpListA.size(), dpListB.size()); - assertEquals(dListA.size(), dListB.size()); - - for (int e = 0; e < dpListA.size(); e++) { - val rA = dpListA.get(e).getPoint(); - val rB = dpListB.get(e).getPoint(); - - assertEquals(rA, rB); - } - } - } - } - - - @Test - public void testNewConstructor_1() { - val points = Nd4j.rand(new int[] {10, 15}); - val treeA = new VPTree(points, true, 2); - - val rows = Nd4j.tear(points, 1); - - val list = new ArrayList<DataPoint>(); - - int idx = 0; - for (val r: rows) - list.add(new DataPoint(idx++, r)); - - val treeB = new VPTree(list); - - assertEquals(points, treeA.getItems()); - assertEquals(points, treeB.getItems()); - } - - @Test - @Ignore - public void testBigTrees_1() throws Exception { - val list = new ArrayList<DataPoint>(); - - for (int e = 0; e < 3200000; e++) { - val dp = new DataPoint(e, Nd4j.rand(new long[] {1, 300})); - } - - log.info("DataPoints created"); - } -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java deleted file mode 100644 index d5ced0cd2..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/test/java/org/deeplearning4j/clustering/vptree/VpTreeNodeTest.java +++ /dev/null @@ -1,414 +0,0 @@ -/* - *
****************************************************************************** - * * - * * - * * This program and the accompanying materials are made available under the - * * terms of the Apache License, Version 2.0 which is available at - * * https://www.apache.org/licenses/LICENSE-2.0. - * * - * * See the NOTICE file distributed with this work for additional - * * information regarding copyright ownership. - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * * License for the specific language governing permissions and limitations - * * under the License. - * * - * * SPDX-License-Identifier: Apache-2.0 - * ***************************************************************************** - */ - -package org.deeplearning4j.clustering.vptree; - -import lombok.extern.slf4j.Slf4j; -import lombok.val; -import org.deeplearning4j.BaseDL4JTest; -import org.deeplearning4j.clustering.sptree.DataPoint; -import org.joda.time.Duration; -import org.junit.BeforeClass; -import org.junit.Test; -import org.nd4j.linalg.api.buffer.DataType; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.exception.ND4JIllegalStateException; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.common.primitives.Counter; -import org.nd4j.common.primitives.Pair; - -import java.util.*; - -import static org.junit.Assert.*; - -/** - * @author Anatoly Borisov - */ -@Slf4j -public class VpTreeNodeTest extends BaseDL4JTest { - - - private static class DistIndex implements Comparable<DistIndex> { - public double dist; - public int index; - - public int compareTo(DistIndex r) { - return Double.compare(dist, r.dist); - } - } - - @BeforeClass - public static void beforeClass(){ - Nd4j.setDataType(DataType.FLOAT); - } - - @Test - public void testKnnK() { - INDArray arr = Nd4j.randn(10, 5); - VPTree t = new VPTree(arr, false); - List<DataPoint> resultList = new ArrayList<>(); - List<Double> distances = new ArrayList<>(); - t.search(arr.getRow(0), 5, resultList, distances); - assertEquals(5, resultList.size()); - } - - - @Test - public void testParallel_1() { - int k = 5; - - for (int e = 0; e < 5; e++) { - Nd4j.getRandom().setSeed(7); - INDArray randn = Nd4j.rand(100, 3); - VPTree vpTree = new VPTree(randn, false, 4); - Nd4j.getRandom().setSeed(7); - VPTree vpTreeNoParallel = new VPTree(randn, false, 1); - List<DataPoint> results = new ArrayList<>(); - List<Double> distances = new ArrayList<>(); - List<DataPoint> noParallelResults = new ArrayList<>(); - List<Double> noDistances = new ArrayList<>(); - vpTree.search(randn.getRow(0), k, results, distances, true); - vpTreeNoParallel.search(randn.getRow(0), k, noParallelResults, noDistances, true); - - assertEquals("Failed at iteration " + e, k, results.size()); - assertEquals("Failed at iteration " + e, noParallelResults.size(), results.size()); - assertNotEquals(randn.getRow(0, true), results.get(0).getPoint()); - assertEquals("Failed at iteration " + e, noParallelResults, results); - assertEquals("Failed at iteration " + e, noDistances, distances); - } - } - - @Test - public void testParallel_2() { - int k = 5; - - for (int e = 0; e < 5; e++) { - Nd4j.getRandom().setSeed(7); - INDArray randn = Nd4j.rand(100, 3); - VPTree vpTree = new VPTree(randn, false, 4); - Nd4j.getRandom().setSeed(7); - VPTree vpTreeNoParallel = new VPTree(randn, false, 1); - List<DataPoint> results = new ArrayList<>(); - List<Double> distances = new ArrayList<>(); - List<DataPoint> noParallelResults = new ArrayList<>(); - List<Double> noDistances = new ArrayList<>(); - vpTree.search(randn.getRow(0), k, results, distances, false); - vpTreeNoParallel.search(randn.getRow(0), k, noParallelResults, noDistances, false); - - assertEquals("Failed at iteration " + e, k, results.size()); - assertEquals("Failed at iteration " + e, noParallelResults.size(), results.size()); - assertEquals(randn.getRow(0, true), results.get(0).getPoint()); - assertEquals("Failed at iteration " + e, noParallelResults, results); - assertEquals("Failed at iteration " + e, noDistances, distances); - } - } - - @Test - public void testReproducibility() { - val results = new ArrayList<DataPoint>(); - val distances = new ArrayList<Double>(); - Nd4j.getRandom().setSeed(7); - val randn = Nd4j.rand(1000, 100); - - for (int e = 0; e < 10; e++) { - Nd4j.getRandom().setSeed(7); - val vpTree = new VPTree(randn, false, 1); - - val cresults = new ArrayList<DataPoint>(); - val cdistances = new ArrayList<Double>(); - vpTree.search(randn.getRow(0), 5, cresults, cdistances); - - if (e == 0) { - results.addAll(cresults); - distances.addAll(cdistances); - } else { - assertEquals("Failed at iteration " + e, results, cresults); - assertEquals("Failed at iteration " + e, distances, cdistances); - } - } - } - - @Test - public void knnManualRandom() { - knnManual(Nd4j.randn(3, 5)); - } - - @Test - public void knnManualNaturals() { - knnManual(generateNaturalsMatrix(20, 2)); - } - - public static void knnManual(INDArray arr) { - Nd4j.getRandom().setSeed(7); - VPTree t = new VPTree(arr, false); - int k = 1; - int m = arr.rows(); - for (int targetIndex = 0; targetIndex < m; targetIndex++) { - // Do an exhaustive search - TreeSet<Integer> s = new TreeSet<>(); - INDArray query = arr.getRow(targetIndex, true); - - Counter<Integer> counter = new Counter<>(); - for (int j = 0; j < m; j++) { - double d = t.distance(query, (arr.getRow(j, true))); - counter.setCount(j, (float) d); - - } - - PriorityQueue<Pair<Integer, Float>> pq = counter.asReversedPriorityQueue(); - // keep closest k - for (int i = 0; i < k; i++) { - Pair<Integer, Float> di = pq.poll(); - System.out.println("exhaustive d=" + di.getFirst()); - s.add(di.getFirst()); - } - - // Check what VPTree gives for results - List<DataPoint> results = new ArrayList<>(); - VPTreeFillSearch fillSearch = new VPTreeFillSearch(t, k, query); - fillSearch.search(); - results = fillSearch.getResults(); - - //List<DataPoint> items = t.getItems(); - TreeSet<Integer> resultSet = new TreeSet<>(); - - // keep k in a set - for (int i = 0; i < k; ++i) { - DataPoint result = results.get(i); - int r = result.getIndex(); - resultSet.add(r); - } - - - - // check - for (int r : resultSet) { - INDArray expectedResult = arr.getRow(r, true); - if (!s.contains(r)) { - fillSearch = new VPTreeFillSearch(t, k, query); - fillSearch.search(); - results = fillSearch.getResults(); - } - assertTrue(String.format( - "VPTree result %d is not in the closest %d from the exhaustive search with query point %s and result %s and target not found %s", - r, k, query.toString(), results.toString(), expectedResult.toString()), s.contains(r)); - } - - } - } - - @Test - public void vpTreeTest() { - List<DataPoint> points = new ArrayList<>(); - points.add(new DataPoint(0, Nd4j.create(new double[] {55, 55}))); - points.add(new DataPoint(1, Nd4j.create(new double[] {60, 60}))); - points.add(new DataPoint(2, Nd4j.create(new double[] {65, 65}))); - VPTree tree = new VPTree(points, "euclidean"); - List<DataPoint> add = new ArrayList<>(); - List<Double> distances = new ArrayList<>(); - tree.search(Nd4j.create(new double[] {50, 50}), 1, add, distances); - DataPoint assertion = add.get(0); - assertEquals(new DataPoint(0, Nd4j.create(new double[] {55, 55}).reshape(1,2)), assertion); - - tree.search(Nd4j.create(new double[] {61, 61}), 2, add, distances, false); - assertion = add.get(0); - assertEquals(Nd4j.create(new double[] {60, 60}).reshape(1,2), assertion.getPoint()); - } - - @Test(expected = ND4JIllegalStateException.class) - public void vpTreeTest2() { - List<DataPoint> points = new ArrayList<>(); - points.add(new DataPoint(0, Nd4j.create(new double[] {55, 55}))); - points.add(new DataPoint(1, Nd4j.create(new double[] {60, 60}))); - points.add(new DataPoint(2, Nd4j.create(new double[] {65, 65}))); - VPTree tree = new VPTree(points, "euclidean"); - - tree.search(Nd4j.create(1, 10), 2, new ArrayList<DataPoint>(), new ArrayList<Double>()); - } - - @Test(expected = ND4JIllegalStateException.class) - public void vpTreeTest3() { - List<DataPoint> points = new ArrayList<>(); - points.add(new DataPoint(0, Nd4j.create(new double[] {55, 55}))); - points.add(new DataPoint(1, Nd4j.create(new double[] {60, 60}))); - points.add(new DataPoint(2, Nd4j.create(new double[] {65, 65}))); - VPTree tree = new VPTree(points, "euclidean"); - - tree.search(Nd4j.create(2, 10), 2, new ArrayList<DataPoint>(), new ArrayList<Double>()); - } - - @Test(expected = ND4JIllegalStateException.class) - public void vpTreeTest4() { - List<DataPoint> points = new ArrayList<>(); - points.add(new DataPoint(0, Nd4j.create(new double[] {55, 55}))); - points.add(new DataPoint(1, Nd4j.create(new double[] {60, 60}))); - points.add(new DataPoint(2, Nd4j.create(new double[] {65, 65}))); - VPTree tree = new VPTree(points, "euclidean"); - - tree.search(Nd4j.create(2, 10, 10), 2, new ArrayList<DataPoint>(), new ArrayList<Double>()); - } - - public static INDArray generateNaturalsMatrix(int nrows, int ncols) { - INDArray col = Nd4j.arange(0, nrows).reshape(nrows, 1).castTo(DataType.DOUBLE); - INDArray points = Nd4j.create(DataType.DOUBLE, nrows, ncols); - if (points.isColumnVectorOrScalar()) - points = col.dup(); - else { - for (int i = 0; i < ncols; i++) - points.putColumn(i, col); - } - return points; - } - - @Test - public void testVPSearchOverNaturals1D() throws Exception { - testVPSearchOverNaturalsPD(20, 1, 5); - } - - @Test - public void testVPSearchOverNaturals2D() throws Exception { - testVPSearchOverNaturalsPD(20, 2, 5); - } - - @Test - public void testTreeOrder() { - - int N = 10, dim = 1; - INDArray dataset = Nd4j.randn(N, dim); - double[] rawData = dataset.toDoubleVector(); - Arrays.sort(dataset.toDoubleVector()); - dataset = Nd4j.createFromArray(rawData).reshape(1,N); - - List<DataPoint> points = new ArrayList<>(); - - for (int i = 0; i < rawData.length; ++i) { - points.add(new DataPoint(i, Nd4j.create(new double[]{rawData[i]}))); - } - - VPTree tree = new VPTree(points, "euclidean"); - INDArray points1 = tree.getItems(); - assertEquals(dataset, points1); - } - - @Test - public void testNearestNeighbors() { - - List<DataPoint> points = new ArrayList<>(); - - points.add(new DataPoint(0, Nd4j.create(new double[] {0.83494041, 1.70294823, -1.34172191, 0.02350972, - -0.87519361, 0.64401935, -0.5634212, -1.1274308, - 0.19245948, -0.11349026}))); - points.add(new DataPoint(1, Nd4j.create(new double[] {-0.41115537, -0.7686138, -0.67923172, 1.01638281, - 0.04390801, 0.29753166, 0.78915771, -0.13564866, - -1.06053692, -0.15953041}))); - - VPTree tree = new VPTree(points, "euclidean"); - - List<DataPoint> results = new ArrayList<>(); - List<Double> distances = new ArrayList<>(); - - final int k = 1; - double[] input = new double[]{0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5}; - tree.search(Nd4j.createFromArray(input), k, results, distances); - assertEquals(k, distances.size()); - assertEquals(2.7755637844503016, distances.get(0), 1e-5); - - double[] results_pattern = new double[]{-0.41115537, -0.7686138, -0.67923172, 1.01638281, 0.04390801, - 0.29753166, 0.78915771, -0.13564866, -1.06053692, -0.15953041}; - for (int i = 0; i < results_pattern.length; ++i) { - assertEquals(results_pattern[i], results.get(0).getPoint().getDouble(i), 1e-5); - } - } - - @Test - public void performanceTest() { - final int dim = 300; - final int rows = 8000; - final int k = 5; - - INDArray inputArray = Nd4j.linspace(DataType.DOUBLE, 0.0, 1.0, rows * dim).reshape(rows, dim); - - //INDArray inputArray = Nd4j.randn(DataType.DOUBLE, 200000, dim); - long start = System.currentTimeMillis(); - VPTree tree = new VPTree(inputArray, "euclidean"); - long end = System.currentTimeMillis(); - Duration duration = new Duration(start, end); - System.out.println("Elapsed time for tree construction " + duration.getStandardSeconds()); - - double[] input = new double[dim]; - for (int i = 0; i < dim; ++i) { - input[i] = 119; - } - List<DataPoint> results = new ArrayList<>(); - List<Double> distances = new ArrayList<>(); - start = System.currentTimeMillis(); - tree.search(Nd4j.createFromArray(input), k, results, distances); - end = System.currentTimeMillis(); - duration = new Duration(start, end); - System.out.println("Elapsed time for tree search " + duration.getStandardSeconds()); - assertEquals(1590.2987519949422, distances.get(0), 1e-4); - } - - public static void testVPSearchOverNaturalsPD(int nrows, int ncols, int K) throws Exception { - final int queryPoint = 12; - - INDArray points = generateNaturalsMatrix(nrows, ncols); - INDArray query = Nd4j.zeros(DataType.DOUBLE, 1, ncols); - for (int i = 0; i < ncols; i++) - query.putScalar(0, i, queryPoint); - - INDArray trueResults = Nd4j.zeros(DataType.DOUBLE, K, ncols); - for (int j = 0; j < K; j++) { - int pt = queryPoint - K / 2 + j; - for (int i = 0; i < ncols; i++) - trueResults.putScalar(j, i, pt); - } - - VPTree tree = new VPTree(points, "euclidean", 1, false); - - List<DataPoint> results = new ArrayList<>(); - List<Double> distances = new ArrayList<>(); - tree.search(query, K, results, distances, false); - int dimensionToSort = 0; - - INDArray sortedResults = Nd4j.zeros(DataType.DOUBLE, K, ncols); - int i = 0; - for (DataPoint p : results) { - sortedResults.putRow(i++, p.getPoint()); - } - - sortedResults = Nd4j.sort(sortedResults, dimensionToSort, true); - assertTrue(trueResults.equalsWithEps(sortedResults, 1e-5)); - - VPTreeFillSearch fillSearch = new VPTreeFillSearch(tree, K, query); - fillSearch.search(); - results = fillSearch.getResults(); - sortedResults = Nd4j.zeros(DataType.FLOAT, K, ncols); - i = 0; - for (DataPoint p : results) - sortedResults.putRow(i++, p.getPoint()); - INDArray[] sortedWithIndices = Nd4j.sortWithIndices(sortedResults, dimensionToSort, true); - sortedResults = sortedWithIndices[1]; - assertEquals(trueResults.sumNumber().doubleValue(), sortedResults.sumNumber().doubleValue(), 1e-5); - } - -} diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml deleted file mode 100644 index b95ab2c73..000000000 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/pom.xml +++ /dev/null @@ -1,54 +0,0 @@ - - - - - - 4.0.0 - - - org.deeplearning4j - deeplearning4j-parent - 1.0.0-SNAPSHOT - - - deeplearning4j-nearestneighbors-parent - pom - - deeplearning4j-nearestneighbors-parent - - - deeplearning4j-nearestneighbor-server -
nearestneighbor-core - deeplearning4j-nearestneighbors-client - deeplearning4j-nearestneighbors-model - - - - - test-nd4j-native - - - test-nd4j-cuda-11.0 - - - diff --git a/deeplearning4j/pom.xml b/deeplearning4j/pom.xml index acd187417..c17e67df1 100644 --- a/deeplearning4j/pom.xml +++ b/deeplearning4j/pom.xml @@ -36,7 +36,6 @@ pom DeepLearning4j - http://deeplearning4j.org/ DeepLearning for java @@ -59,7 +58,6 @@ deeplearning4j-modelimport deeplearning4j-modelexport-solr deeplearning4j-zoo - deeplearning4j-nearestneighbors-parent deeplearning4j-data deeplearning4j-manifold dl4j-integration-tests