Remove more unused modules

master
agibsonccc 2021-03-06 08:43:58 +09:00
parent fa8537f0c7
commit ee06fdd16f
110 changed files with 0 additions and 14983 deletions

View File

@@ -1,64 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.datavec</groupId>
<artifactId>datavec-spark-inference-parent</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>datavec-spark-inference-client</artifactId>
<name>datavec-spark-inference-client</name>
<dependencies>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-spark-inference-server_2.11</artifactId>
<version>1.0.0-SNAPSHOT</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-spark-inference-model</artifactId>
<version>${project.parent.version}</version>
</dependency>
<dependency>
<groupId>com.mashape.unirest</groupId>
<artifactId>unirest-java</artifactId>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
</profile>
</profiles>
</project>

View File

@@ -1,292 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.inference.client;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import com.mashape.unirest.http.exceptions.UnirestException;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.transform.TransformProcess;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.model.*;
import org.datavec.spark.inference.model.service.DataVecTransformService;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import java.io.IOException;
@AllArgsConstructor
@Slf4j
public class DataVecTransformClient implements DataVecTransformService {
private String url;
static {
// Configure Unirest's shared object mapper only once, for all client instances
Unirest.setObjectMapper(new ObjectMapper() {
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
new org.nd4j.shade.jackson.databind.ObjectMapper();
public <T> T readValue(String value, Class<T> valueType) {
try {
return jacksonObjectMapper.readValue(value, valueType);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public String writeValue(Object value) {
try {
return jacksonObjectMapper.writeValueAsString(value);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
});
}
/**
* @param transformProcess
*/
@Override
public void setCSVTransformProcess(TransformProcess transformProcess) {
try {
String s = transformProcess.toJson();
Unirest.post(url + "/transformprocess").header("accept", "application/json")
.header("Content-Type", "application/json").body(s).asJson();
} catch (UnirestException e) {
log.error("Error in setCSVTransformProcess()", e);
}
}
@Override
public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) {
throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
}
/**
* @return
*/
@Override
public TransformProcess getCSVTransformProcess() {
try {
String s = Unirest.get(url + "/transformprocess").header("accept", "application/json")
.header("Content-Type", "application/json").asString().getBody();
return TransformProcess.fromJson(s);
} catch (UnirestException e) {
log.error("Error in getCSVTransformProcess()",e);
}
return null;
}
@Override
public ImageTransformProcess getImageTransformProcess() {
throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
}
/**
* @param transform
* @return
*/
@Override
public SingleCSVRecord transformIncremental(SingleCSVRecord transform) {
try {
SingleCSVRecord singleCsvRecord = Unirest.post(url + "/transformincremental")
.header("accept", "application/json")
.header("Content-Type", "application/json")
.body(transform).asObject(SingleCSVRecord.class).getBody();
return singleCsvRecord;
} catch (UnirestException e) {
log.error("Error in transformIncremental(SingleCSVRecord)",e);
}
return null;
}
/**
* @param batchCSVRecord
* @return
*/
@Override
public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
try {
SequenceBatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform").header("accept", "application/json")
.header("Content-Type", "application/json")
.header(SEQUENCE_OR_NOT_HEADER,"TRUE")
.body(batchCSVRecord)
.asObject(SequenceBatchCSVRecord.class)
.getBody();
return batchCSVRecord1;
} catch (UnirestException e) {
log.error("",e);
}
return null;
}
/**
* @param batchCSVRecord
* @return
*/
@Override
public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) {
try {
BatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform").header("accept", "application/json")
.header("Content-Type", "application/json")
.header(SEQUENCE_OR_NOT_HEADER,"FALSE")
.body(batchCSVRecord)
.asObject(BatchCSVRecord.class)
.getBody();
return batchCSVRecord1;
} catch (UnirestException e) {
log.error("Error in transform(BatchCSVRecord)", e);
}
return null;
}
/**
* @param batchCSVRecord
* @return
*/
@Override
public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) {
try {
Base64NDArrayBody batchArray1 = Unirest.post(url + "/transformarray").header("accept", "application/json")
.header("Content-Type", "application/json").body(batchCSVRecord)
.asObject(Base64NDArrayBody.class).getBody();
return batchArray1;
} catch (UnirestException e) {
log.error("Error in transformArray(BatchCSVRecord)",e);
}
return null;
}
/**
* @param singleCsvRecord
* @return
*/
@Override
public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) {
try {
Base64NDArrayBody array = Unirest.post(url + "/transformincrementalarray")
.header("accept", "application/json").header("Content-Type", "application/json")
.body(singleCsvRecord).asObject(Base64NDArrayBody.class).getBody();
return array;
} catch (UnirestException e) {
log.error("Error in transformArrayIncremental(SingleCSVRecord)",e);
}
return null;
}
@Override
public Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException {
throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
}
@Override
public Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException {
throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
}
/**
* @param singleCsvRecord
* @return
*/
@Override
public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) {
try {
Base64NDArrayBody array = Unirest.post(url + "/transformincrementalarray")
.header("accept", "application/json")
.header("Content-Type", "application/json")
.header(SEQUENCE_OR_NOT_HEADER,"true")
.body(singleCsvRecord).asObject(Base64NDArrayBody.class).getBody();
return array;
} catch (UnirestException e) {
log.error("Error in transformSequenceArrayIncremental",e);
}
return null;
}
/**
* @param batchCSVRecord
* @return
*/
@Override
public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) {
try {
Base64NDArrayBody batchArray1 = Unirest.post(url + "/transformarray").header("accept", "application/json")
.header("Content-Type", "application/json")
.header(SEQUENCE_OR_NOT_HEADER,"true")
.body(batchCSVRecord)
.asObject(Base64NDArrayBody.class).getBody();
return batchArray1;
} catch (UnirestException e) {
log.error("Error in transformSequenceArray",e);
}
return null;
}
/**
* @param batchCSVRecord
* @return
*/
@Override
public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) {
try {
SequenceBatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform")
.header("accept", "application/json")
.header("Content-Type", "application/json")
.header(SEQUENCE_OR_NOT_HEADER,"true")
.body(batchCSVRecord)
.asObject(SequenceBatchCSVRecord.class).getBody();
return batchCSVRecord1;
} catch (UnirestException e) {
log.error("Error in transformSequence");
}
return null;
}
/**
* @param transform
* @return
*/
@Override
public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) {
try {
SequenceBatchCSVRecord singleCsvRecord = Unirest.post(url + "/transformincremental")
.header("accept", "application/json")
.header("Content-Type", "application/json")
.header(SEQUENCE_OR_NOT_HEADER,"true")
.body(transform).asObject(SequenceBatchCSVRecord.class).getBody();
return singleCsvRecord;
} catch (UnirestException e) {
log.error("Error in transformSequenceIncremental");
}
return null;
}
}
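For orientation, below is a minimal, hypothetical usage sketch of the client removed above. The schema, column names, and server URL are illustrative only, and it assumes a CSVSparkTransformServer (see the test removed further down) is already listening at that address.

import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.spark.inference.client.DataVecTransformClient;
import org.datavec.spark.inference.model.model.SingleCSVRecord;

public class DataVecTransformClientExample {
    public static void main(String[] args) {
        // Illustrative two-column schema and transform process
        Schema schema = new Schema.Builder()
                .addColumnDouble("a")
                .addColumnDouble("b")
                .build();
        TransformProcess tp = new TransformProcess.Builder(schema)
                .convertToString("a")
                .convertToString("b")
                .build();

        // Assumes a transform server is already running at this (hypothetical) address
        DataVecTransformClient client = new DataVecTransformClient("http://localhost:9600");
        client.setCSVTransformProcess(tp);

        // Transform a single CSV record on the server and print the result
        SingleCSVRecord transformed = client.transformIncremental(new SingleCSVRecord("1.0", "2.0"));
        System.out.println(transformed.getValues());
    }
}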

View File

@@ -1,45 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.transform.client;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.transform.client";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@@ -1,139 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.transform.client;
import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.spark.inference.server.CSVSparkTransformServer;
import org.datavec.spark.inference.client.DataVecTransformClient;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;
import java.io.File;
import java.io.IOException;
import java.net.ServerSocket;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;
import static org.junit.Assert.assertEquals;
import static org.junit.Assume.assumeNotNull;
public class DataVecTransformClientTest {
private static CSVSparkTransformServer server;
private static int port = getAvailablePort();
private static DataVecTransformClient client;
private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
private static TransformProcess transformProcess =
new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build();
private static File fileSave = new File(UUID.randomUUID().toString() + ".json");
@BeforeClass
public static void beforeClass() throws Exception {
FileUtils.write(fileSave, transformProcess.toJson());
fileSave.deleteOnExit();
server = new CSVSparkTransformServer();
server.runMain(new String[] {"-dp", String.valueOf(port)});
client = new DataVecTransformClient("http://localhost:" + port);
client.setCSVTransformProcess(transformProcess);
}
@AfterClass
public static void afterClass() throws Exception {
server.stop();
}
@Test
public void testSequenceClient() {
SequenceBatchCSVRecord sequenceBatchCSVRecord = new SequenceBatchCSVRecord();
SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"});
BatchCSVRecord batchCSVRecord = new BatchCSVRecord(Arrays.asList(singleCsvRecord, singleCsvRecord));
List<BatchCSVRecord> batchCSVRecordList = new ArrayList<>();
for(int i = 0; i < 5; i++) {
batchCSVRecordList.add(batchCSVRecord);
}
sequenceBatchCSVRecord.add(batchCSVRecordList);
SequenceBatchCSVRecord sequenceBatchCSVRecord1 = client.transformSequence(sequenceBatchCSVRecord);
assumeNotNull(sequenceBatchCSVRecord1);
Base64NDArrayBody array = client.transformSequenceArray(sequenceBatchCSVRecord);
assumeNotNull(array);
Base64NDArrayBody incrementalBody = client.transformSequenceArrayIncremental(batchCSVRecord);
assumeNotNull(incrementalBody);
Base64NDArrayBody incrementalSequenceBody = client.transformSequenceArrayIncremental(batchCSVRecord);
assumeNotNull(incrementalSequenceBody);
}
@Test
public void testRecord() throws Exception {
SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"});
SingleCSVRecord transformed = client.transformIncremental(singleCsvRecord);
assertEquals(singleCsvRecord.getValues().size(), transformed.getValues().size());
Base64NDArrayBody body = client.transformArrayIncremental(singleCsvRecord);
INDArray arr = Nd4jBase64.fromBase64(body.getNdarray());
assumeNotNull(arr);
}
@Test
public void testBatchRecord() throws Exception {
SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"});
BatchCSVRecord batchCSVRecord = new BatchCSVRecord(Arrays.asList(singleCsvRecord, singleCsvRecord));
BatchCSVRecord batchCSVRecord1 = client.transform(batchCSVRecord);
assertEquals(batchCSVRecord.getRecords().size(), batchCSVRecord1.getRecords().size());
Base64NDArrayBody body = client.transformArray(batchCSVRecord);
INDArray arr = Nd4jBase64.fromBase64(body.getNdarray());
assumeNotNull(arr);
}
public static int getAvailablePort() {
try {
ServerSocket socket = new ServerSocket(0);
try {
return socket.getLocalPort();
} finally {
socket.close();
}
} catch (IOException e) {
throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e);
}
}
}

View File

@@ -1,6 +0,0 @@
play.modules.enabled += com.lightbend.lagom.discovery.zookeeper.ZooKeeperServiceLocatorModule
play.modules.enabled += io.skymind.skil.service.PredictionModule
play.crypto.secret = as8dufasdfuasdfjkasdkfalksjfk
play.server.pidfile.path=/tmp/RUNNING_PID
play.server.http.port = 9600

View File

@@ -1,63 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.datavec</groupId>
<artifactId>datavec-spark-inference-parent</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>datavec-spark-inference-model</artifactId>
<name>datavec-spark-inference-model</name>
<dependencies>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>
<version>${datavec.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-data-image</artifactId>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-local</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
</profile>
</profiles>
</project>

View File

@@ -1,286 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.inference.model;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.Writable;
import org.datavec.arrow.ArrowConverter;
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import static org.datavec.arrow.ArrowConverter.*;
import static org.datavec.local.transforms.LocalTransformExecutor.execute;
import static org.datavec.local.transforms.LocalTransformExecutor.executeToSequence;
@AllArgsConstructor
@Slf4j
public class CSVSparkTransform {
@Getter
private TransformProcess transformProcess;
private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
/**
* Convert a raw record via
* the {@link TransformProcess}
* to a base 64ed ndarray
* @param batch the record to convert
* @return the base 64ed ndarray
* @throws IOException
*/
public Base64NDArrayBody toArray(BatchCSVRecord batch) throws IOException {
List<List<Writable>> converted = execute(toArrowWritables(toArrowColumnsString(
bufferAllocator,transformProcess.getInitialSchema(),
batch.getRecordsAsString()),
transformProcess.getInitialSchema()),transformProcess);
ArrowWritableRecordBatch arrowRecordBatch = (ArrowWritableRecordBatch) converted;
INDArray convert = ArrowConverter.toArray(arrowRecordBatch);
return new Base64NDArrayBody(Nd4jBase64.base64String(convert));
}
/**
* Convert a raw record via
* the {@link TransformProcess}
* to a base 64ed ndarray
* @param record the record to convert
* @return the base 64ed ndarray
* @throws IOException
*/
public Base64NDArrayBody toArray(SingleCSVRecord record) throws IOException {
List<Writable> record2 = toArrowWritablesSingle(
toArrowColumnsStringSingle(bufferAllocator,
transformProcess.getInitialSchema(),record.getValues()),
transformProcess.getInitialSchema());
List<Writable> finalRecord = execute(Arrays.asList(record2),transformProcess).get(0);
INDArray convert = RecordConverter.toArray(DataType.DOUBLE, finalRecord);
return new Base64NDArrayBody(Nd4jBase64.base64String(convert));
}
/**
* Runs the transform process
* @param batch the record to transform
* @return the transformed record
*/
public BatchCSVRecord transform(BatchCSVRecord batch) {
BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
List<List<Writable>> converted = execute(toArrowWritables(toArrowColumnsString(
bufferAllocator,transformProcess.getInitialSchema(),
batch.getRecordsAsString()),
transformProcess.getInitialSchema()),transformProcess);
int numCols = converted.get(0).size();
for (int row = 0; row < converted.size(); row++) {
String[] values = new String[numCols];
for (int i = 0; i < values.length; i++)
values[i] = converted.get(row).get(i).toString();
batchCSVRecord.add(new SingleCSVRecord(values));
}
return batchCSVRecord;
}
/**
* Runs the transform process
* @param record the record to transform
* @return the transformed record
*/
public SingleCSVRecord transform(SingleCSVRecord record) {
List<Writable> record2 = toArrowWritablesSingle(
toArrowColumnsStringSingle(bufferAllocator,
transformProcess.getInitialSchema(),record.getValues()),
transformProcess.getInitialSchema());
List<Writable> finalRecord = execute(Arrays.asList(record2),transformProcess).get(0);
String[] values = new String[finalRecord.size()];
for (int i = 0; i < values.length; i++)
values[i] = finalRecord.get(i).toString();
return new SingleCSVRecord(values);
}
/**
*
* @param transform
* @return
*/
public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) {
/**
* Sequence schema?
*/
List<List<List<Writable>>> converted = executeToSequence(
toArrowWritables(toArrowColumnsStringTimeSeries(
bufferAllocator, transformProcess.getInitialSchema(),
Arrays.asList(transform.getRecordsAsString())),
transformProcess.getInitialSchema()), transformProcess);
SequenceBatchCSVRecord batchCSVRecord = new SequenceBatchCSVRecord();
for (int i = 0; i < converted.size(); i++) {
BatchCSVRecord batchCSVRecord1 = BatchCSVRecord.fromWritables(converted.get(i));
batchCSVRecord.add(Arrays.asList(batchCSVRecord1));
}
return batchCSVRecord;
}
/**
*
* @param batchCSVRecordSequence
* @return
*/
public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecordSequence) {
List<List<List<String>>> recordsAsString = batchCSVRecordSequence.getRecordsAsString();
//check whether every sequence in the batch has the same number of steps; equal-length sequences are converted directly to an arrow time series batch
boolean allSameLength = true;
Integer length = null;
for(List<List<String>> record : recordsAsString) {
if(length == null) {
length = record.size();
}
else if(record.size() != length) {
allSameLength = false;
}
}
if(allSameLength) {
List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), recordsAsString);
ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors,
transformProcess.getInitialSchema(),
recordsAsString.get(0).get(0).size());
val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch,transformProcess);
return SequenceBatchCSVRecord.fromWritables(transformed);
}
else {
val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecordSequence.getRecordsAsString(),transformProcess.getInitialSchema()),transformProcess);
return SequenceBatchCSVRecord.fromWritables(transformed);
}
}
/**
* TODO: optimize
* @param batchCSVRecordSequence
* @return
*/
public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecordSequence) {
List<List<List<String>>> strings = batchCSVRecordSequence.getRecordsAsString();
boolean allSameLength = true;
Integer length = null;
for(List<List<String>> record : strings) {
if(length == null) {
length = record.size();
}
else if(record.size() != length) {
allSameLength = false;
}
}
if(allSameLength) {
List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), strings);
ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors,transformProcess.getInitialSchema(),strings.get(0).get(0).size());
val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch,transformProcess);
INDArray arr = RecordConverter.toTensor(transformed).reshape(strings.size(),strings.get(0).get(0).size(),strings.get(0).size());
try {
return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
} catch (IOException e) {
throw new IllegalStateException(e);
}
}
else {
val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecordSequence.getRecordsAsString(),transformProcess.getInitialSchema()),transformProcess);
INDArray arr = RecordConverter.toTensor(transformed).reshape(strings.size(),strings.get(0).get(0).size(),strings.get(0).size());
try {
return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
} catch (IOException e) {
throw new IllegalStateException(e);
}
}
}
/**
*
* @param singleCsvRecord
* @return
*/
public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) {
List<List<List<Writable>>> converted = executeToSequence(toArrowWritables(toArrowColumnsString(
bufferAllocator,transformProcess.getInitialSchema(),
singleCsvRecord.getRecordsAsString()),
transformProcess.getInitialSchema()),transformProcess);
ArrowWritableRecordTimeSeriesBatch arrowWritableRecordBatch = (ArrowWritableRecordTimeSeriesBatch) converted;
INDArray arr = RecordConverter.toTensor(arrowWritableRecordBatch);
try {
return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
} catch (IOException e) {
log.error("",e);
}
return null;
}
public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
List<List<List<String>>> strings = batchCSVRecord.getRecordsAsString();
boolean allSameLength = true;
Integer length = null;
for(List<List<String>> record : strings) {
if(length == null) {
length = record.size();
}
else if(record.size() != length) {
allSameLength = false;
}
}
if(allSameLength) {
List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), strings);
ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors,transformProcess.getInitialSchema(),strings.get(0).get(0).size());
val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch,transformProcess);
return SequenceBatchCSVRecord.fromWritables(transformed);
}
else {
val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecord.getRecordsAsString(),transformProcess.getInitialSchema()),transformProcess);
return SequenceBatchCSVRecord.fromWritables(transformed);
}
}
}

View File

@@ -1,64 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.inference.model;
import lombok.AllArgsConstructor;
import lombok.Getter;
import org.datavec.image.data.ImageWritable;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchImageRecord;
import org.datavec.spark.inference.model.model.SingleImageRecord;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.serde.base64.Nd4jBase64;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
@AllArgsConstructor
public class ImageSparkTransform {
@Getter
private ImageTransformProcess imageTransformProcess;
public Base64NDArrayBody toArray(SingleImageRecord record) throws IOException {
ImageWritable record2 = imageTransformProcess.transformFileUriToInput(record.getUri());
INDArray finalRecord = imageTransformProcess.executeArray(record2);
return new Base64NDArrayBody(Nd4jBase64.base64String(finalRecord));
}
public Base64NDArrayBody toArray(BatchImageRecord batch) throws IOException {
List<INDArray> records = new ArrayList<>();
for (SingleImageRecord imgRecord : batch.getRecords()) {
ImageWritable record2 = imageTransformProcess.transformFileUriToInput(imgRecord.getUri());
INDArray finalRecord = imageTransformProcess.executeArray(record2);
records.add(finalRecord);
}
INDArray array = Nd4j.concat(0, records.toArray(new INDArray[records.size()]));
return new Base64NDArrayBody(Nd4jBase64.base64String(array));
}
}

View File

@@ -1,32 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.inference.model.model;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class Base64NDArrayBody {
private String ndarray;
}

View File

@@ -1,104 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.inference.model.model;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.dataset.DataSet;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
@Data
@AllArgsConstructor
@Builder
@NoArgsConstructor
public class BatchCSVRecord implements Serializable {
private List<SingleCSVRecord> records;
/**
* Get the records as a list of strings
* (basically the underlying values for
* {@link SingleCSVRecord})
* @return
*/
public List<List<String>> getRecordsAsString() {
if(records == null)
records = new ArrayList<>();
List<List<String>> ret = new ArrayList<>();
for(SingleCSVRecord csvRecord : records) {
ret.add(csvRecord.getValues());
}
return ret;
}
/**
* Create a batch csv record
* from a list of writables.
* @param batch
* @return
*/
public static BatchCSVRecord fromWritables(List<List<Writable>> batch) {
List <SingleCSVRecord> records = new ArrayList<>(batch.size());
for(List<Writable> list : batch) {
List<String> add = new ArrayList<>(list.size());
for(Writable writable : list) {
add.add(writable.toString());
}
records.add(new SingleCSVRecord(add));
}
return BatchCSVRecord.builder().records(records).build();
}
/**
* Add a record
* @param record
*/
public void add(SingleCSVRecord record) {
if (records == null)
records = new ArrayList<>();
records.add(record);
}
/**
* Return a batch record based on a dataset
* @param dataSet the dataset to get the batch record for
* @return the batch record
*/
public static BatchCSVRecord fromDataSet(DataSet dataSet) {
BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
for (int i = 0; i < dataSet.numExamples(); i++) {
batchCSVRecord.add(SingleCSVRecord.fromRow(dataSet.get(i)));
}
return batchCSVRecord;
}
}

View File

@@ -1,50 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.inference.model.model;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.net.URI;
import java.util.ArrayList;
import java.util.List;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class BatchImageRecord {
private List<SingleImageRecord> records;
/**
* Add a record
* @param record
*/
public void add(SingleImageRecord record) {
if (records == null)
records = new ArrayList<>();
records.add(record);
}
public void add(URI uri) {
this.add(new SingleImageRecord(uri));
}
}

View File

@@ -1,106 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.inference.model.model;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@Data
@AllArgsConstructor
@Builder
@NoArgsConstructor
public class SequenceBatchCSVRecord implements Serializable {
private List<List<BatchCSVRecord>> records;
/**
* Add a record
* @param record
*/
public void add(List<BatchCSVRecord> record) {
if (records == null)
records = new ArrayList<>();
records.add(record);
}
/**
* Get the records as a list of strings directly
* (this basically "unpacks" the objects)
* @return
*/
public List<List<List<String>>> getRecordsAsString() {
if(records == null)
return Collections.emptyList();
List<List<List<String>>> ret = new ArrayList<>(records.size());
for(List<BatchCSVRecord> record : records) {
List<List<String>> add = new ArrayList<>();
for(BatchCSVRecord batchCSVRecord : record) {
for (SingleCSVRecord singleCSVRecord : batchCSVRecord.getRecords()) {
add.add(singleCSVRecord.getValues());
}
}
ret.add(add);
}
return ret;
}
/**
* Convert a writables time series to a sequence batch
* @param input
* @return
*/
public static SequenceBatchCSVRecord fromWritables(List<List<List<Writable>>> input) {
SequenceBatchCSVRecord ret = new SequenceBatchCSVRecord();
for(int i = 0; i < input.size(); i++) {
ret.add(Arrays.asList(BatchCSVRecord.fromWritables(input.get(i))));
}
return ret;
}
/**
* Return a batch record based on a dataset
* @param dataSet the dataset to get the batch record for
* @return the batch record
*/
public static SequenceBatchCSVRecord fromDataSet(MultiDataSet dataSet) {
SequenceBatchCSVRecord batchCSVRecord = new SequenceBatchCSVRecord();
for (int i = 0; i < dataSet.numFeatureArrays(); i++) {
batchCSVRecord.add(Arrays.asList(BatchCSVRecord.fromDataSet(new DataSet(dataSet.getFeatures(i),dataSet.getLabels(i)))));
}
return batchCSVRecord;
}
}

View File

@@ -1,95 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.inference.model.model;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.dataset.DataSet;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class SingleCSVRecord implements Serializable {
private List<String> values;
/**
* Create from an array of values (uses a list internally)
* @param values
*/
public SingleCSVRecord(String...values) {
this.values = Arrays.asList(values);
}
/**
* Instantiate a CSV record from a single {@link DataSet} row.
* For classification (a one-hot label vector), the index of the
* maximum label value is appended to the end of the record; for
* regression, all label values are appended.
* @param row the input row
* @return the record built from this {@link DataSet}
*/
public static SingleCSVRecord fromRow(DataSet row) {
if (!row.getFeatures().isVector() && !row.getFeatures().isScalar())
throw new IllegalArgumentException("Passed in dataset must represent a scalar or vector");
if (!row.getLabels().isVector() && !row.getLabels().isScalar())
throw new IllegalArgumentException("Passed in dataset labels must be a scalar or vector");
//classification
SingleCSVRecord record;
int idx = 0;
if (row.getLabels().sumNumber().doubleValue() == 1.0) {
String[] values = new String[row.getFeatures().columns() + 1];
for (int i = 0; i < row.getFeatures().length(); i++) {
values[idx++] = String.valueOf(row.getFeatures().getDouble(i));
}
int maxIdx = 0;
for (int i = 0; i < row.getLabels().length(); i++) {
if (row.getLabels().getDouble(maxIdx) < row.getLabels().getDouble(i)) {
maxIdx = i;
}
}
values[idx++] = String.valueOf(maxIdx);
record = new SingleCSVRecord(values);
}
//regression (any number of values)
else {
String[] values = new String[row.getFeatures().columns() + row.getLabels().columns()];
for (int i = 0; i < row.getFeatures().length(); i++) {
values[idx++] = String.valueOf(row.getFeatures().getDouble(i));
}
for (int i = 0; i < row.getLabels().length(); i++) {
values[idx++] = String.valueOf(row.getLabels().getDouble(i));
}
record = new SingleCSVRecord(values);
}
return record;
}
}
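To make the fromRow conversion above concrete, here is a small hypothetical sketch (array values chosen for illustration): a one-hot labelled row yields the feature values plus the argmax label index, while a row whose labels do not sum to 1.0 is treated as regression and yields the feature values plus every label value.

import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

public class SingleCSVRecordFromRowExample {
    public static void main(String[] args) {
        // Classification-style row: features [0.5, 1.5], one-hot labels [0, 1, 0]
        DataSet classificationRow = new DataSet(
                Nd4j.create(new double[][] {{0.5, 1.5}}),
                Nd4j.create(new double[][] {{0, 1, 0}}));
        // Prints [0.5, 1.5, 1] - the argmax of the labels is appended
        System.out.println(SingleCSVRecord.fromRow(classificationRow).getValues());

        // Regression-style row: features [0.5, 1.5], labels [2.0, 3.0]
        DataSet regressionRow = new DataSet(
                Nd4j.create(new double[][] {{0.5, 1.5}}),
                Nd4j.create(new double[][] {{2.0, 3.0}}));
        // Prints [0.5, 1.5, 2.0, 3.0] - all label values are appended
        System.out.println(SingleCSVRecord.fromRow(regressionRow).getValues());
    }
}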

View File

@@ -1,34 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.inference.model.model;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.net.URI;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class SingleImageRecord {
private URI uri;
}

View File

@@ -1,131 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.inference.model.service;
import org.datavec.api.transform.TransformProcess;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.model.*;
import java.io.IOException;
public interface DataVecTransformService {
String SEQUENCE_OR_NOT_HEADER = "Sequence";
/**
*
* @param transformProcess
*/
void setCSVTransformProcess(TransformProcess transformProcess);
/**
*
* @param imageTransformProcess
*/
void setImageTransformProcess(ImageTransformProcess imageTransformProcess);
/**
*
* @return
*/
TransformProcess getCSVTransformProcess();
/**
*
* @return
*/
ImageTransformProcess getImageTransformProcess();
/**
*
* @param singleCsvRecord
* @return
*/
SingleCSVRecord transformIncremental(SingleCSVRecord singleCsvRecord);
SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord);
/**
*
* @param batchCSVRecord
* @return
*/
BatchCSVRecord transform(BatchCSVRecord batchCSVRecord);
/**
*
* @param batchCSVRecord
* @return
*/
Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord);
/**
*
* @param singleCsvRecord
* @return
*/
Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord);
/**
*
* @param singleImageRecord
* @return
* @throws IOException
*/
Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException;
/**
*
* @param batchImageRecord
* @return
* @throws IOException
*/
Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException;
/**
*
* @param singleCsvRecord
* @return
*/
Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord);
/**
*
* @param batchCSVRecord
* @return
*/
Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord);
/**
*
* @param batchCSVRecord
* @return
*/
SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord);
/**
*
* @param transform
* @return
*/
SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform);
}

View File

@@ -1,46 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.transform;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.spark.transform";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@@ -1,40 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.transform;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.junit.Test;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertEquals;
public class BatchCSVRecordTest {
@Test
public void testBatchRecordCreationFromDataSet() {
DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{1, 1}, {1, 1}}));
BatchCSVRecord batchCSVRecord = BatchCSVRecord.fromDataSet(dataSet);
assertEquals(2, batchCSVRecord.getRecords().size());
}
}

View File

@@ -1,212 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.transform;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.integer.BaseIntegerTransform;
import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.spark.inference.model.CSVSparkTransform;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;
import java.util.*;
import static org.junit.Assert.*;
public class CSVSparkTransformTest {
@Test
public void testTransformer() throws Exception {
List<Writable> input = new ArrayList<>();
input.add(new DoubleWritable(1.0));
input.add(new DoubleWritable(2.0));
Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
List<Writable> output = new ArrayList<>();
output.add(new Text("1.0"));
output.add(new Text("2.0"));
TransformProcess transformProcess =
new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build();
CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess);
String[] values = new String[] {"1.0", "2.0"};
SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values));
Base64NDArrayBody body = csvSparkTransform.toArray(new SingleCSVRecord(values));
INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
assertTrue(fromBase64.isVector());
// System.out.println("Base 64ed array " + fromBase64);
}
@Test
public void testTransformerBatch() throws Exception {
List<Writable> input = new ArrayList<>();
input.add(new DoubleWritable(1.0));
input.add(new DoubleWritable(2.0));
Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
List<Writable> output = new ArrayList<>();
output.add(new Text("1.0"));
output.add(new Text("2.0"));
TransformProcess transformProcess =
new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build();
CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess);
String[] values = new String[] {"1.0", "2.0"};
SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values));
BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
for (int i = 0; i < 3; i++)
batchCSVRecord.add(record);
//data type is string, unable to convert
BatchCSVRecord batchCSVRecord1 = csvSparkTransform.transform(batchCSVRecord);
/* Base64NDArrayBody body = csvSparkTransform.toArray(batchCSVRecord1);
INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
assertTrue(fromBase64.isMatrix());
System.out.println("Base 64ed array " + fromBase64); */
}
@Test
public void testSingleBatchSequence() throws Exception {
List<Writable> input = new ArrayList<>();
input.add(new DoubleWritable(1.0));
input.add(new DoubleWritable(2.0));
Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
List<Writable> output = new ArrayList<>();
output.add(new Text("1.0"));
output.add(new Text("2.0"));
TransformProcess transformProcess =
new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build();
CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess);
String[] values = new String[] {"1.0", "2.0"};
SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values));
BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
for (int i = 0; i < 3; i++)
batchCSVRecord.add(record);
BatchCSVRecord batchCSVRecord1 = csvSparkTransform.transform(batchCSVRecord);
SequenceBatchCSVRecord sequenceBatchCSVRecord = new SequenceBatchCSVRecord();
sequenceBatchCSVRecord.add(Arrays.asList(batchCSVRecord));
Base64NDArrayBody sequenceArray = csvSparkTransform.transformSequenceArray(sequenceBatchCSVRecord);
INDArray outputBody = Nd4jBase64.fromBase64(sequenceArray.getNdarray());
//ensure accumulation
sequenceBatchCSVRecord.add(Arrays.asList(batchCSVRecord));
sequenceArray = csvSparkTransform.transformSequenceArray(sequenceBatchCSVRecord);
assertArrayEquals(new long[]{2,2,3},Nd4jBase64.fromBase64(sequenceArray.getNdarray()).shape());
SequenceBatchCSVRecord transformed = csvSparkTransform.transformSequence(sequenceBatchCSVRecord);
assertNotNull(transformed.getRecords());
// System.out.println(transformed);
}
@Test
public void testSpecificSequence() throws Exception {
final Schema schema = new Schema.Builder()
.addColumnsString("action")
.build();
final TransformProcess transformProcess = new TransformProcess.Builder(schema)
.removeAllColumnsExceptFor("action")
.transform(new ConverToLowercase("action"))
.convertToSequence()
.transform(new TextToCharacterIndexTransform("action", "action_sequence",
defaultCharIndex(), false))
.integerToOneHot("action_sequence",0,29)
.build();
final String[] data1 = new String[] { "test1" };
final String[] data2 = new String[] { "test2" };
final BatchCSVRecord batchCsvRecord = new BatchCSVRecord(
Arrays.asList(
new SingleCSVRecord(data1),
new SingleCSVRecord(data2)));
final CSVSparkTransform transform = new CSVSparkTransform(transformProcess);
// System.out.println(transform.transformSequenceIncremental(batchCsvRecord));
transform.transformSequenceIncremental(batchCsvRecord);
assertEquals(3,Nd4jBase64.fromBase64(transform.transformSequenceArrayIncremental(batchCsvRecord).getNdarray()).rank());
}
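    // Character-to-index map used by the TextToCharacterIndexTransform above:
    // 'a'-'z', '/', ' ', '(' and ')' map to indices 0..29, matching integerToOneHot("action_sequence", 0, 29).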
private static Map<Character,Integer> defaultCharIndex() {
Map<Character,Integer> ret = new TreeMap<>();
ret.put('a',0);
ret.put('b',1);
ret.put('c',2);
ret.put('d',3);
ret.put('e',4);
ret.put('f',5);
ret.put('g',6);
ret.put('h',7);
ret.put('i',8);
ret.put('j',9);
ret.put('k',10);
ret.put('l',11);
ret.put('m',12);
ret.put('n',13);
ret.put('o',14);
ret.put('p',15);
ret.put('q',16);
ret.put('r',17);
ret.put('s',18);
ret.put('t',19);
ret.put('u',20);
ret.put('v',21);
ret.put('w',22);
ret.put('x',23);
ret.put('y',24);
ret.put('z',25);
ret.put('/',26);
ret.put(' ',27);
ret.put('(',28);
ret.put(')',29);
return ret;
}
    public static class ConvertToLowercase extends BaseIntegerTransform {
        public ConvertToLowercase(String column) {
super(column);
}
public Text map(Writable writable) {
return new Text(writable.toString().toLowerCase());
}
public Object map(Object input) {
return new Text(input.toString().toLowerCase());
}
}
}

View File

@ -1,86 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.transform;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.ImageSparkTransform;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchImageRecord;
import org.datavec.spark.inference.model.model.SingleImageRecord;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.serde.base64.Nd4jBase64;
import java.io.File;
import static org.junit.Assert.assertEquals;
public class ImageSparkTransformTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
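    // Both tests apply a scale-then-crop ImageTransformProcess and check the leading (batch) dimension of the resulting NDArray.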
@Test
public void testSingleImageSparkTransform() throws Exception {
int seed = 12345;
File f1 = new ClassPathResource("datavec-spark-inference/testimages/class1/A.jpg").getFile();
SingleImageRecord imgRecord = new SingleImageRecord(f1.toURI());
ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(seed)
.scaleImageTransform(10).cropImageTransform(5).build();
ImageSparkTransform imgSparkTransform = new ImageSparkTransform(imgTransformProcess);
Base64NDArrayBody body = imgSparkTransform.toArray(imgRecord);
INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
// System.out.println("Base 64ed array " + fromBase64);
assertEquals(1, fromBase64.size(0));
}
@Test
public void testBatchImageSparkTransform() throws Exception {
int seed = 12345;
File f0 = new ClassPathResource("datavec-spark-inference/testimages/class1/A.jpg").getFile();
File f1 = new ClassPathResource("datavec-spark-inference/testimages/class1/B.png").getFile();
File f2 = new ClassPathResource("datavec-spark-inference/testimages/class1/C.jpg").getFile();
BatchImageRecord batch = new BatchImageRecord();
batch.add(f0.toURI());
batch.add(f1.toURI());
batch.add(f2.toURI());
ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(seed)
.scaleImageTransform(10).cropImageTransform(5).build();
ImageSparkTransform imgSparkTransform = new ImageSparkTransform(imgTransformProcess);
Base64NDArrayBody body = imgSparkTransform.toArray(batch);
INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
// System.out.println("Base 64ed array " + fromBase64);
assertEquals(3, fromBase64.size(0));
}
}

View File

@ -1,60 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.transform;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.junit.Test;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
public class SingleCSVRecordTest {
@Test(expected = IllegalArgumentException.class)
public void testVectorAssertion() {
DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(1, 1));
SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet);
fail(singleCsvRecord.toString() + " should have thrown an exception");
}
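    // With a one-hot label, fromRow yields the 2 feature values plus a single label value (3 total);
    // with a regression label it keeps both label values (4 total, see testVectorRegression).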
@Test
public void testVectorOneHotLabel() {
DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{0, 1}, {1, 0}}));
//assert
SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet.get(0));
assertEquals(3, singleCsvRecord.getValues().size());
}
@Test
public void testVectorRegression() {
DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{1, 1}, {1, 1}}));
//assert
SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet.get(0));
assertEquals(4, singleCsvRecord.getValues().size());
}
}

View File

@ -1,47 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.transform;
import org.datavec.spark.inference.model.model.SingleImageRecord;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.io.ClassPathResource;
import java.io.File;
public class SingleImageRecordTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testImageRecord() throws Exception {
File f = testDir.newFolder();
new ClassPathResource("datavec-spark-inference/testimages/").copyDirectory(f);
File f0 = new File(f, "class0/0.jpg");
File f1 = new File(f, "/class1/A.jpg");
SingleImageRecord imgRecord = new SingleImageRecord(f0.toURI());
// need jackson test?
}
}

View File

@ -1,154 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.datavec</groupId>
<artifactId>datavec-spark-inference-parent</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>datavec-spark-inference-server_2.11</artifactId>
<name>datavec-spark-inference-server</name>
<properties>
<!-- Default scala versions, may be overwritten by build profiles -->
<scala.version>2.11.12</scala.version>
<scala.binary.version>2.11</scala.binary.version>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-spark-inference-model</artifactId>
<version>${datavec.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-spark_2.11</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-data-image</artifactId>
</dependency>
<dependency>
<groupId>joda-time</groupId>
<artifactId>joda-time</artifactId>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
</dependency>
<dependency>
<groupId>org.hibernate</groupId>
<artifactId>hibernate-validator</artifactId>
<version>${hibernate.version}</version>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-library</artifactId>
<version>${scala.version}</version>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-reflect</artifactId>
<version>${scala.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-java_2.11</artifactId>
<version>${playframework.version}</version>
<exclusions>
<exclusion>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
</exclusion>
<exclusion>
<groupId>net.jodah</groupId>
<artifactId>typetools</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>net.jodah</groupId>
<artifactId>typetools</artifactId>
<version>${jodah.typetools.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-json_2.11</artifactId>
<version>${playframework.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-server_2.11</artifactId>
<version>${playframework.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play_2.11</artifactId>
<version>${playframework.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.play</groupId>
<artifactId>play-netty-server_2.11</artifactId>
<version>${playframework.version}</version>
</dependency>
<dependency>
<groupId>com.typesafe.akka</groupId>
<artifactId>akka-cluster_2.11</artifactId>
<version>2.5.23</version>
</dependency>
<dependency>
<groupId>com.mashape.unirest</groupId>
<artifactId>unirest-java</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.beust</groupId>
<artifactId>jcommander</artifactId>
<version>${jcommander.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-core_2.11</artifactId>
<version>${spark.version}</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
</profile>
</profiles>
</project>

View File

@ -1,352 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.inference.server;
import com.beust.jcommander.JCommander;
import com.beust.jcommander.ParameterException;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.CSVSparkTransform;
import org.datavec.spark.inference.model.model.*;
import play.BuiltInComponents;
import play.Mode;
import play.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;
import java.io.File;
import java.io.IOException;
import java.util.Base64;
import java.util.Random;
import static play.mvc.Results.*;
@Slf4j
@Data
public class CSVSparkTransformServer extends SparkTransformServer {
private CSVSparkTransform transform;
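    // Parses command-line args with JCommander, optionally loads a TransformProcess from --jsonPath,
    // generates a random Play application secret if none is configured, and starts the Play server on the given port.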
public void runMain(String[] args) throws Exception {
JCommander jcmdr = new JCommander(this);
try {
jcmdr.parse(args);
} catch (ParameterException e) {
//User provides invalid input -> print the usage info
jcmdr.usage();
if (jsonPath == null)
System.err.println("Json path parameter is missing.");
try {
Thread.sleep(500);
} catch (Exception e2) {
}
System.exit(1);
}
if (jsonPath != null) {
String json = FileUtils.readFileToString(new File(jsonPath));
TransformProcess transformProcess = TransformProcess.fromJson(json);
transform = new CSVSparkTransform(transformProcess);
} else {
log.warn("Server started with no json for transform process. Please ensure you specify a transform process via sending a post request with raw json"
+ "to /transformprocess");
}
//Set play secret key, if required
//http://www.playframework.com/documentation/latest/ApplicationSecret
String crypto = System.getProperty("play.crypto.secret");
if (crypto == null || "changeme".equals(crypto) || "".equals(crypto) ) {
byte[] newCrypto = new byte[1024];
new Random().nextBytes(newCrypto);
String base64 = Base64.getEncoder().encodeToString(newCrypto);
System.setProperty("play.crypto.secret", base64);
}
server = Server.forRouter(Mode.PROD, port, this::createRouter);
}
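    // Registers the REST endpoints: GET/POST /transformprocess plus POST /transformincremental, /transform,
    // /transformincrementalarray and /transformarray; sequence vs. single/batch handling is selected per request via isSequence().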
protected Router createRouter(BuiltInComponents b){
RoutingDsl routingDsl = RoutingDsl.fromComponents(b);
routingDsl.GET("/transformprocess").routingTo(req -> {
try {
if (transform == null)
return badRequest();
return ok(transform.getTransformProcess().toJson()).as(contentType);
} catch (Exception e) {
log.error("Error in GET /transformprocess",e);
return internalServerError(e.getMessage());
}
});
routingDsl.POST("/transformprocess").routingTo(req -> {
try {
TransformProcess transformProcess = TransformProcess.fromJson(getJsonText(req));
setCSVTransformProcess(transformProcess);
log.info("Transform process initialized");
return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
} catch (Exception e) {
log.error("Error in POST /transformprocess",e);
return internalServerError(e.getMessage());
}
});
routingDsl.POST("/transformincremental").routingTo(req -> {
if (isSequence(req)) {
try {
BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
if (record == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformSequenceIncremental(record))).as(contentType);
} catch (Exception e) {
log.error("Error in /transformincremental", e);
return internalServerError(e.getMessage());
}
} else {
try {
SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class);
if (record == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformIncremental(record))).as(contentType);
} catch (Exception e) {
log.error("Error in /transformincremental", e);
return internalServerError(e.getMessage());
}
}
});
routingDsl.POST("/transform").routingTo(req -> {
if (isSequence(req)) {
try {
SequenceBatchCSVRecord batch = transformSequence(objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class));
if (batch == null)
return badRequest();
return ok(objectMapper.writeValueAsString(batch)).as(contentType);
} catch (Exception e) {
log.error("Error in /transform", e);
return internalServerError(e.getMessage());
}
} else {
try {
BatchCSVRecord input = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
BatchCSVRecord batch = transform(input);
if (batch == null)
return badRequest();
return ok(objectMapper.writeValueAsString(batch)).as(contentType);
} catch (Exception e) {
log.error("Error in /transform", e);
return internalServerError(e.getMessage());
}
}
});
routingDsl.POST("/transformincrementalarray").routingTo(req -> {
if (isSequence(req)) {
try {
BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
if (record == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformSequenceArrayIncremental(record))).as(contentType);
} catch (Exception e) {
log.error("Error in /transformincrementalarray", e);
return internalServerError(e.getMessage());
}
} else {
try {
SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class);
if (record == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformArrayIncremental(record))).as(contentType);
} catch (Exception e) {
log.error("Error in /transformincrementalarray", e);
return internalServerError(e.getMessage());
}
}
});
routingDsl.POST("/transformarray").routingTo(req -> {
if (isSequence(req)) {
try {
SequenceBatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class);
if (batchCSVRecord == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformSequenceArray(batchCSVRecord))).as(contentType);
} catch (Exception e) {
log.error("Error in /transformarray", e);
return internalServerError(e.getMessage());
}
} else {
try {
BatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
if (batchCSVRecord == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformArray(batchCSVRecord))).as(contentType);
} catch (Exception e) {
log.error("Error in /transformarray", e);
return internalServerError(e.getMessage());
}
}
});
return routingDsl.build();
}
public static void main(String[] args) throws Exception {
new CSVSparkTransformServer().runMain(args);
}
/**
* @param transformProcess
*/
@Override
public void setCSVTransformProcess(TransformProcess transformProcess) {
this.transform = new CSVSparkTransform(transformProcess);
}
@Override
public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) {
log.error("Unsupported operation: setImageTransformProcess not supported for class", getClass());
throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
}
/**
* @return
*/
@Override
public TransformProcess getCSVTransformProcess() {
return transform.getTransformProcess();
}
@Override
public ImageTransformProcess getImageTransformProcess() {
log.error("Unsupported operation: getImageTransformProcess not supported for class", getClass());
throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
}
/**
* @param transform
* @return
*/
@Override
public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) {
return this.transform.transformSequenceIncremental(transform);
}
/**
* @param batchCSVRecord
* @return
*/
@Override
public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) {
return transform.transformSequence(batchCSVRecord);
}
/**
* @param batchCSVRecord
* @return
*/
@Override
public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) {
return this.transform.transformSequenceArray(batchCSVRecord);
}
/**
* @param singleCsvRecord
* @return
*/
@Override
public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) {
return this.transform.transformSequenceArrayIncremental(singleCsvRecord);
}
/**
* @param transform
* @return
*/
@Override
public SingleCSVRecord transformIncremental(SingleCSVRecord transform) {
return this.transform.transform(transform);
}
@Override
public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
return this.transform.transform(batchCSVRecord);
}
/**
* @param batchCSVRecord
* @return
*/
@Override
public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) {
return transform.transform(batchCSVRecord);
}
/**
* @param batchCSVRecord
* @return
*/
@Override
public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) {
try {
return this.transform.toArray(batchCSVRecord);
} catch (IOException e) {
log.error("Error in transformArray",e);
throw new IllegalStateException("Transform array shouldn't throw exception");
}
}
/**
* @param singleCsvRecord
* @return
*/
@Override
public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) {
try {
return this.transform.toArray(singleCsvRecord);
} catch (IOException e) {
log.error("Error in transformArrayIncremental",e);
throw new IllegalStateException("Transform array shouldn't throw exception");
}
}
@Override
public Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException {
log.error("Unsupported operation: transformIncrementalArray(SingleImageRecord) not supported for class", getClass());
throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
}
@Override
public Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException {
log.error("Unsupported operation: transformArray(BatchImageRecord) not supported for class", getClass());
throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
}
}

View File

@ -1,261 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.inference.server;
import com.beust.jcommander.JCommander;
import com.beust.jcommander.ParameterException;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.ImageSparkTransform;
import org.datavec.spark.inference.model.model.*;
import play.BuiltInComponents;
import play.Mode;
import play.libs.Files;
import play.mvc.Http;
import play.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import static play.mvc.Results.*;
@Slf4j
@Data
public class ImageSparkTransformServer extends SparkTransformServer {
private ImageSparkTransform transform;
public void runMain(String[] args) throws Exception {
JCommander jcmdr = new JCommander(this);
try {
jcmdr.parse(args);
} catch (ParameterException e) {
//User provides invalid input -> print the usage info
jcmdr.usage();
if (jsonPath == null)
System.err.println("Json path parameter is missing.");
try {
Thread.sleep(500);
} catch (Exception e2) {
}
System.exit(1);
}
if (jsonPath != null) {
String json = FileUtils.readFileToString(new File(jsonPath));
ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(json);
transform = new ImageSparkTransform(transformProcess);
} else {
log.warn("Server started with no json for transform process. Please ensure you specify a transform process via sending a post request with raw json"
+ "to /transformprocess");
}
server = Server.forRouter(Mode.PROD, port, this::createRouter);
}
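    // Registers the image endpoints: GET/POST /transformprocess, POST /transformincrementalarray and /transformarray
    // for JSON bodies, and POST /transformincrementalimage and /transformimage for multipart image uploads.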
protected Router createRouter(BuiltInComponents builtInComponents){
RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents);
routingDsl.GET("/transformprocess").routingTo(req -> {
try {
if (transform == null)
return badRequest();
log.info("Transform process initialized");
return ok(objectMapper.writeValueAsString(transform.getImageTransformProcess())).as(contentType);
} catch (Exception e) {
log.error("",e);
return internalServerError();
}
});
routingDsl.POST("/transformprocess").routingTo(req -> {
try {
ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(getJsonText(req));
setImageTransformProcess(transformProcess);
log.info("Transform process initialized");
return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
} catch (Exception e) {
log.error("",e);
return internalServerError();
}
});
routingDsl.POST("/transformincrementalarray").routingTo(req -> {
try {
SingleImageRecord record = objectMapper.readValue(getJsonText(req), SingleImageRecord.class);
if (record == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
} catch (Exception e) {
log.error("",e);
return internalServerError();
}
});
routingDsl.POST("/transformincrementalimage").routingTo(req -> {
try {
Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData();
List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles();
if (files.isEmpty() || files.get(0).getRef() == null ) {
return badRequest();
}
File file = files.get(0).getRef().path().toFile();
SingleImageRecord record = new SingleImageRecord(file.toURI());
return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
} catch (Exception e) {
log.error("",e);
return internalServerError();
}
});
routingDsl.POST("/transformarray").routingTo(req -> {
try {
BatchImageRecord batch = objectMapper.readValue(getJsonText(req), BatchImageRecord.class);
if (batch == null)
return badRequest();
return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
} catch (Exception e) {
log.error("",e);
return internalServerError();
}
});
routingDsl.POST("/transformimage").routingTo(req -> {
try {
Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData();
List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles();
if (files.size() == 0) {
return badRequest();
}
List<SingleImageRecord> records = new ArrayList<>();
for (Http.MultipartFormData.FilePart<Files.TemporaryFile> filePart : files) {
Files.TemporaryFile file = filePart.getRef();
if (file != null) {
SingleImageRecord record = new SingleImageRecord(file.path().toUri());
records.add(record);
}
}
BatchImageRecord batch = new BatchImageRecord(records);
return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
} catch (Exception e) {
log.error("",e);
return internalServerError();
}
});
return routingDsl.build();
}
@Override
public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) {
throw new UnsupportedOperationException();
}
@Override
public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) {
throw new UnsupportedOperationException();
}
@Override
public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) {
throw new UnsupportedOperationException();
}
@Override
public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) {
throw new UnsupportedOperationException();
}
@Override
public void setCSVTransformProcess(TransformProcess transformProcess) {
throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
}
@Override
public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) {
this.transform = new ImageSparkTransform(imageTransformProcess);
}
@Override
public TransformProcess getCSVTransformProcess() {
throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
}
@Override
public ImageTransformProcess getImageTransformProcess() {
return transform.getImageTransformProcess();
}
@Override
public SingleCSVRecord transformIncremental(SingleCSVRecord singleCsvRecord) {
throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
}
@Override
public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
}
@Override
public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) {
throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
}
@Override
public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) {
throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
}
@Override
public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) {
throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
}
@Override
public Base64NDArrayBody transformIncrementalArray(SingleImageRecord record) throws IOException {
return transform.toArray(record);
}
@Override
public Base64NDArrayBody transformArray(BatchImageRecord batch) throws IOException {
return transform.toArray(batch);
}
public static void main(String[] args) throws Exception {
new ImageSparkTransformServer().runMain(args);
}
}

View File

@ -1,67 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.inference.server;
import com.beust.jcommander.Parameter;
import com.fasterxml.jackson.databind.JsonNode;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.service.DataVecTransformService;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import play.mvc.Http;
import play.server.Server;
public abstract class SparkTransformServer implements DataVecTransformService {
@Parameter(names = {"-j", "--jsonPath"}, arity = 1)
protected String jsonPath = null;
@Parameter(names = {"-dp", "--dataVecPort"}, arity = 1)
protected int port = 9000;
@Parameter(names = {"-dt", "--dataType"}, arity = 1)
private TransformDataType transformDataType = null;
protected Server server;
protected static ObjectMapper objectMapper = new ObjectMapper();
protected static String contentType = "application/json";
public abstract void runMain(String[] args) throws Exception;
/**
* Stop the server
*/
public void stop() {
if (server != null)
server.stop();
}
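    // A request is treated as a sequence payload when the SEQUENCE_OR_NOT_HEADER header is present and set to "true".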
protected boolean isSequence(Http.Request request) {
return request.hasHeader(SEQUENCE_OR_NOT_HEADER)
&& request.header(SEQUENCE_OR_NOT_HEADER).get().equalsIgnoreCase("true");
}
protected String getJsonText(Http.Request request) {
JsonNode tryJson = request.body().asJson();
if (tryJson != null)
return tryJson.toString();
else
return request.body().asText();
}
public abstract Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord);
}

View File

@ -1,76 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.inference.server;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import java.io.InvalidClassException;
import java.util.Arrays;
import java.util.List;
@Data
@Slf4j
public class SparkTransformServerChooser {
private SparkTransformServer sparkTransformServer = null;
private TransformDataType transformDataType = null;
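    // Reads the -dt/--dataType argument and delegates to either CSVSparkTransformServer or ImageSparkTransformServer.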
public void runMain(String[] args) throws Exception {
int pos = getMatchingPosition(args, "-dt", "--dataType");
if (pos == -1) {
log.error("no valid options");
log.error("-dt, --dataType Options: [CSV, IMAGE]");
throw new Exception("no valid options");
} else {
transformDataType = TransformDataType.valueOf(args[pos + 1]);
}
switch (transformDataType) {
case CSV:
sparkTransformServer = new CSVSparkTransformServer();
break;
case IMAGE:
sparkTransformServer = new ImageSparkTransformServer();
break;
default:
throw new InvalidClassException("no matching SparkTransform class");
}
sparkTransformServer.runMain(args);
}
private int getMatchingPosition(String[] args, String... options) {
        List<String> optionList = Arrays.asList(options);
for (int i = 0; i < args.length; i++) {
if (optionList.contains(args[i])) {
return i;
}
}
return -1;
}
public static void main(String[] args) throws Exception {
new SparkTransformServerChooser().runMain(args);
}
}

View File

@ -1,25 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.inference.server;
public enum TransformDataType {
CSV, IMAGE,
}

View File

@ -1,350 +0,0 @@
# This is the main configuration file for the application.
# https://www.playframework.com/documentation/latest/ConfigFile
# ~~~~~
# Play uses HOCON as its configuration file format. HOCON has a number
# of advantages over other config formats, but there are two things that
# can be used when modifying settings.
#
# You can include other configuration files in this main application.conf file:
#include "extra-config.conf"
#
# You can declare variables and substitute for them:
#mykey = ${some.value}
#
# And if an environment variable exists when there is no other substitution, then
# HOCON will fall back to substituting environment variable:
#mykey = ${JAVA_HOME}
## Akka
# https://www.playframework.com/documentation/latest/ScalaAkka#Configuration
# https://www.playframework.com/documentation/latest/JavaAkka#Configuration
# ~~~~~
# Play uses Akka internally and exposes Akka Streams and actors in Websockets and
# other streaming HTTP responses.
akka {
# "akka.log-config-on-start" is extraordinarly useful because it log the complete
# configuration at INFO level, including defaults and overrides, so it s worth
# putting at the very top.
#
# Put the following in your conf/logback.xml file:
#
# <logger name="akka.actor" level="INFO" />
#
# And then uncomment this line to debug the configuration.
#
#log-config-on-start = true
}
## Modules
# https://www.playframework.com/documentation/latest/Modules
# ~~~~~
# Control which modules are loaded when Play starts. Note that modules are
# the replacement for "GlobalSettings", which is deprecated in 2.5.x.
# Please see https://www.playframework.com/documentation/latest/GlobalSettings
# for more information.
#
# You can also extend Play functionality by using one of the publicly available
# Play modules: https://playframework.com/documentation/latest/ModuleDirectory
play.modules {
# By default, Play will load any class called Module that is defined
# in the root package (the "app" directory), or you can define them
# explicitly below.
  # If there are any built-in modules that you want to enable, you can list them here.
#enabled += my.application.Module
# If there are any built-in modules that you want to disable, you can list them here.
#disabled += ""
}
## Internationalisation
# https://www.playframework.com/documentation/latest/JavaI18N
# https://www.playframework.com/documentation/latest/ScalaI18N
# ~~~~~
# Play comes with its own i18n settings, which allow the user's preferred language
# to map through to internal messages, or allow the language to be stored in a cookie.
play.i18n {
# The application languages
langs = [ "en" ]
# Whether the language cookie should be secure or not
#langCookieSecure = true
# Whether the HTTP only attribute of the cookie should be set to true
#langCookieHttpOnly = true
}
## Play HTTP settings
# ~~~~~
play.http {
## Router
# https://www.playframework.com/documentation/latest/JavaRouting
# https://www.playframework.com/documentation/latest/ScalaRouting
# ~~~~~
# Define the Router object to use for this application.
# This router will be looked up first when the application is starting up,
# so make sure this is the entry point.
# Furthermore, it's assumed your route file is named properly.
# So for an application router like `my.application.Router`,
# you may need to define a router file `conf/my.application.routes`.
# Default to Routes in the root package (aka "apps" folder) (and conf/routes)
#router = my.application.Router
## Action Creator
# https://www.playframework.com/documentation/latest/JavaActionCreator
# ~~~~~
#actionCreator = null
## ErrorHandler
# https://www.playframework.com/documentation/latest/JavaRouting
# https://www.playframework.com/documentation/latest/ScalaRouting
# ~~~~~
# If null, will attempt to load a class called ErrorHandler in the root package,
#errorHandler = null
## Filters
# https://www.playframework.com/documentation/latest/ScalaHttpFilters
# https://www.playframework.com/documentation/latest/JavaHttpFilters
# ~~~~~
# Filters run code on every request. They can be used to perform
# common logic for all your actions, e.g. adding common headers.
# Defaults to "Filters" in the root package (aka "apps" folder)
# Alternatively you can explicitly register a class here.
#filters += my.application.Filters
## Session & Flash
# https://www.playframework.com/documentation/latest/JavaSessionFlash
# https://www.playframework.com/documentation/latest/ScalaSessionFlash
# ~~~~~
session {
# Sets the cookie to be sent only over HTTPS.
#secure = true
# Sets the cookie to be accessed only by the server.
#httpOnly = true
# Sets the max-age field of the cookie to 5 minutes.
# NOTE: this only sets when the browser will discard the cookie. Play will consider any
# cookie value with a valid signature to be a valid session forever. To implement a server side session timeout,
# you need to put a timestamp in the session and check it at regular intervals to possibly expire it.
#maxAge = 300
# Sets the domain on the session cookie.
#domain = "example.com"
}
flash {
# Sets the cookie to be sent only over HTTPS.
#secure = true
# Sets the cookie to be accessed only by the server.
#httpOnly = true
}
}
## Netty Provider
# https://www.playframework.com/documentation/latest/SettingsNetty
# ~~~~~
play.server.netty {
# Whether the Netty wire should be logged
#log.wire = true
# If you run Play on Linux, you can use Netty's native socket transport
# for higher performance with less garbage.
#transport = "native"
}
## WS (HTTP Client)
# https://www.playframework.com/documentation/latest/ScalaWS#Configuring-WS
# ~~~~~
# The HTTP client primarily used for REST APIs. The default client can be
# configured directly, but you can also create different client instances
# with customized settings. You must enable this by adding to build.sbt:
#
# libraryDependencies += ws // or javaWs if using java
#
play.ws {
  # Sets HTTP requests not to follow 302 redirects
#followRedirects = false
# Sets the maximum number of open HTTP connections for the client.
#ahc.maxConnectionsTotal = 50
## WS SSL
# https://www.playframework.com/documentation/latest/WsSSL
# ~~~~~
ssl {
# Configuring HTTPS with Play WS does not require programming. You can
# set up both trustManager and keyManager for mutual authentication, and
# turn on JSSE debugging in development with a reload.
#debug.handshake = true
#trustManager = {
# stores = [
# { type = "JKS", path = "exampletrust.jks" }
# ]
#}
}
}
## Cache
# https://www.playframework.com/documentation/latest/JavaCache
# https://www.playframework.com/documentation/latest/ScalaCache
# ~~~~~
# Play comes with an integrated cache API that can reduce the operational
# overhead of repeated requests. You must enable this by adding to build.sbt:
#
# libraryDependencies += cache
#
play.cache {
# If you want to bind several caches, you can bind the individually
#bindCaches = ["db-cache", "user-cache", "session-cache"]
}
## Filters
# https://www.playframework.com/documentation/latest/Filters
# ~~~~~
# There are a number of built-in filters that can be enabled and configured
# to give Play greater security. You must enable this by adding to build.sbt:
#
# libraryDependencies += filters
#
play.filters {
## CORS filter configuration
# https://www.playframework.com/documentation/latest/CorsFilter
# ~~~~~
# CORS is a protocol that allows web applications to make requests from the browser
# across different domains.
# NOTE: You MUST apply the CORS configuration before the CSRF filter, as CSRF has
# dependencies on CORS settings.
cors {
# Filter paths by a whitelist of path prefixes
#pathPrefixes = ["/some/path", ...]
# The allowed origins. If null, all origins are allowed.
#allowedOrigins = ["http://www.example.com"]
# The allowed HTTP methods. If null, all methods are allowed
#allowedHttpMethods = ["GET", "POST"]
}
## CSRF Filter
# https://www.playframework.com/documentation/latest/ScalaCsrf#Applying-a-global-CSRF-filter
# https://www.playframework.com/documentation/latest/JavaCsrf#Applying-a-global-CSRF-filter
# ~~~~~
# Play supports multiple methods for verifying that a request is not a CSRF request.
# The primary mechanism is a CSRF token. This token gets placed either in the query string
  # or body of every form submitted, and also gets placed in the user's session.
# Play then verifies that both tokens are present and match.
csrf {
# Sets the cookie to be sent only over HTTPS
#cookie.secure = true
# Defaults to CSRFErrorHandler in the root package.
#errorHandler = MyCSRFErrorHandler
}
## Security headers filter configuration
# https://www.playframework.com/documentation/latest/SecurityHeaders
# ~~~~~
# Defines security headers that prevent XSS attacks.
# If enabled, then all options are set to the below configuration by default:
headers {
# The X-Frame-Options header. If null, the header is not set.
#frameOptions = "DENY"
# The X-XSS-Protection header. If null, the header is not set.
#xssProtection = "1; mode=block"
# The X-Content-Type-Options header. If null, the header is not set.
#contentTypeOptions = "nosniff"
# The X-Permitted-Cross-Domain-Policies header. If null, the header is not set.
#permittedCrossDomainPolicies = "master-only"
# The Content-Security-Policy header. If null, the header is not set.
#contentSecurityPolicy = "default-src 'self'"
}
## Allowed hosts filter configuration
# https://www.playframework.com/documentation/latest/AllowedHostsFilter
# ~~~~~
# Play provides a filter that lets you configure which hosts can access your application.
# This is useful to prevent cache poisoning attacks.
hosts {
# Allow requests to example.com, its subdomains, and localhost:9000.
#allowed = [".example.com", "localhost:9000"]
}
}
## Evolutions
# https://www.playframework.com/documentation/latest/Evolutions
# ~~~~~
# Evolutions allows database scripts to be automatically run on startup in dev mode
# for database migrations. You must enable this by adding to build.sbt:
#
# libraryDependencies += evolutions
#
play.evolutions {
# You can disable evolutions for a specific datasource if necessary
#db.default.enabled = false
}
## Database Connection Pool
# https://www.playframework.com/documentation/latest/SettingsJDBC
# ~~~~~
# Play doesn't require a JDBC database to run, but you can easily enable one.
#
# libraryDependencies += jdbc
#
play.db {
# The combination of these two settings results in "db.default" as the
# default JDBC pool:
#config = "db"
#default = "default"
# Play uses HikariCP as the default connection pool. You can override
# settings by changing the prototype:
prototype {
# Sets a fixed JDBC connection pool size of 50
#hikaricp.minimumIdle = 50
#hikaricp.maximumPoolSize = 50
}
}
## JDBC Datasource
# https://www.playframework.com/documentation/latest/JavaDatabase
# https://www.playframework.com/documentation/latest/ScalaDatabase
# ~~~~~
# Once JDBC datasource is set up, you can work with several different
# database options:
#
# Slick (Scala preferred option): https://www.playframework.com/documentation/latest/PlaySlick
# JPA (Java preferred option): https://playframework.com/documentation/latest/JavaJPA
# EBean: https://playframework.com/documentation/latest/JavaEbean
# Anorm: https://www.playframework.com/documentation/latest/ScalaAnorm
#
db {
# You can declare as many datasources as you want.
# By convention, the default datasource is named `default`
# https://www.playframework.com/documentation/latest/Developing-with-the-H2-Database
default.driver = org.h2.Driver
default.url = "jdbc:h2:mem:play"
#default.username = sa
#default.password = ""
# You can expose this datasource via JNDI if needed (Useful for JPA)
default.jndiName=DefaultDS
# You can turn on SQL logging for any datasource
# https://www.playframework.com/documentation/latest/Highlights25#Logging-SQL-statements
#default.logSql=true
}
jpa.default=defaultPersistenceUnit
#Increase default maximum post length - used for remote listener functionality
#Can get response 413 with larger networks without setting this
# parsers.text.maxLength is deprecated, use play.http.parser.maxMemoryBuffer instead
#parsers.text.maxLength=10M
play.http.parser.maxMemoryBuffer=10M

View File

@ -1,46 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.transform;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;
@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
@Override
protected Set<Class<?>> getExclusions() {
//Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
return new HashSet<>();
}
@Override
protected String getPackageName() {
return "org.datavec.spark.transform";
}
@Override
protected Class<?> getBaseClass() {
return BaseND4JTest.class;
}
}

View File

@ -1,127 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.transform;
import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.spark.inference.server.CSVSparkTransformServer;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import java.io.File;
import java.io.IOException;
import java.util.UUID;
import static org.junit.Assert.assertTrue;
import static org.junit.Assume.assumeNotNull;
public class CSVSparkTransformServerNoJsonTest {
private static CSVSparkTransformServer server;
private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
private static TransformProcess transformProcess =
new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build();
private static File fileSave = new File(UUID.randomUUID().toString() + ".json");
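    // The server is started without --jsonPath; the transform process is pushed afterwards via POST /transformprocess in testServer().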
@BeforeClass
public static void before() throws Exception {
server = new CSVSparkTransformServer();
FileUtils.write(fileSave, transformProcess.toJson());
        // Configure the Unirest object mapper only once for all tests in this class
Unirest.setObjectMapper(new ObjectMapper() {
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
new org.nd4j.shade.jackson.databind.ObjectMapper();
public <T> T readValue(String value, Class<T> valueType) {
try {
return jacksonObjectMapper.readValue(value, valueType);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public String writeValue(Object value) {
try {
return jacksonObjectMapper.writeValueAsString(value);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
});
server.runMain(new String[] {"-dp", "9050"});
}
@AfterClass
public static void after() throws Exception {
fileSave.delete();
server.stop();
}
@Test
public void testServer() throws Exception {
assertTrue(server.getTransform() == null);
JsonNode jsonStatus = Unirest.post("http://localhost:9050/transformprocess")
.header("accept", "application/json").header("Content-Type", "application/json")
.body(transformProcess.toJson()).asJson().getBody();
assumeNotNull(server.getTransform());
String[] values = new String[] {"1.0", "2.0"};
SingleCSVRecord record = new SingleCSVRecord(values);
JsonNode jsonNode =
Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json")
.header("Content-Type", "application/json").body(record).asJson().getBody();
SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental")
.header("accept", "application/json").header("Content-Type", "application/json").body(record)
.asObject(SingleCSVRecord.class).getBody();
BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
for (int i = 0; i < 3; i++)
batchCSVRecord.add(singleCsvRecord);
/* BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform")
.header("accept", "application/json").header("Content-Type", "application/json")
.body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody();
Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray")
.header("accept", "application/json").header("Content-Type", "application/json").body(record)
.asObject(Base64NDArrayBody.class).getBody();
*/
Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray")
.header("accept", "application/json").header("Content-Type", "application/json")
.body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody();
}
}

View File

@ -1,121 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.transform;
import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.spark.inference.server.CSVSparkTransformServer;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import java.io.File;
import java.io.IOException;
import java.util.UUID;
public class CSVSparkTransformServerTest {
private static CSVSparkTransformServer server;
private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
private static TransformProcess transformProcess =
new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build();
private static File fileSave = new File(UUID.randomUUID().toString() + ".json");
@BeforeClass
public static void before() throws Exception {
server = new CSVSparkTransformServer();
FileUtils.write(fileSave, transformProcess.toJson());
// Only one time
Unirest.setObjectMapper(new ObjectMapper() {
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
new org.nd4j.shade.jackson.databind.ObjectMapper();
public <T> T readValue(String value, Class<T> valueType) {
try {
return jacksonObjectMapper.readValue(value, valueType);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public String writeValue(Object value) {
try {
return jacksonObjectMapper.writeValueAsString(value);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
});
server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9050"});
}
@AfterClass
public static void after() throws Exception {
fileSave.deleteOnExit();
server.stop();
}
@Test
public void testServer() throws Exception {
String[] values = new String[] {"1.0", "2.0"};
SingleCSVRecord record = new SingleCSVRecord(values);
JsonNode jsonNode =
Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json")
.header("Content-Type", "application/json").body(record).asJson().getBody();
SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental")
.header("accept", "application/json").header("Content-Type", "application/json").body(record)
.asObject(SingleCSVRecord.class).getBody();
BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
for (int i = 0; i < 3; i++)
batchCSVRecord.add(singleCsvRecord);
BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform")
.header("accept", "application/json").header("Content-Type", "application/json")
.body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody();
Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray")
.header("accept", "application/json").header("Content-Type", "application/json").body(record)
.asObject(Base64NDArrayBody.class).getBody();
Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray")
.header("accept", "application/json").header("Content-Type", "application/json")
.body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody();
}
}

@ -1,164 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.transform;
import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import org.apache.commons.io.FileUtils;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.server.ImageSparkTransformServer;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchImageRecord;
import org.datavec.spark.inference.model.model.SingleImageRecord;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.serde.base64.Nd4jBase64;
import java.io.File;
import java.io.IOException;
import java.util.UUID;
import static org.junit.Assert.assertEquals;
public class ImageSparkTransformServerTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
private static ImageSparkTransformServer server;
private static File fileSave = new File(UUID.randomUUID().toString() + ".json");
@BeforeClass
public static void before() throws Exception {
server = new ImageSparkTransformServer();
ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(12345)
.scaleImageTransform(10).cropImageTransform(5).build();
FileUtils.write(fileSave, imgTransformProcess.toJson());
Unirest.setObjectMapper(new ObjectMapper() {
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
new org.nd4j.shade.jackson.databind.ObjectMapper();
public <T> T readValue(String value, Class<T> valueType) {
try {
return jacksonObjectMapper.readValue(value, valueType);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public String writeValue(Object value) {
try {
return jacksonObjectMapper.writeValueAsString(value);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
});
server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9060"});
}
@AfterClass
public static void after() throws Exception {
fileSave.deleteOnExit();
server.stop();
}
@Test
public void testImageServer() throws Exception {
SingleImageRecord record =
new SingleImageRecord(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI());
JsonNode jsonNode = Unirest.post("http://localhost:9060/transformincrementalarray")
.header("accept", "application/json").header("Content-Type", "application/json").body(record)
.asJson().getBody();
Base64NDArrayBody array = Unirest.post("http://localhost:9060/transformincrementalarray")
.header("accept", "application/json").header("Content-Type", "application/json").body(record)
.asObject(Base64NDArrayBody.class).getBody();
BatchImageRecord batch = new BatchImageRecord();
batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI());
batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile().toURI());
batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile().toURI());
JsonNode jsonNodeBatch =
Unirest.post("http://localhost:9060/transformarray").header("accept", "application/json")
.header("Content-Type", "application/json").body(batch).asJson().getBody();
Base64NDArrayBody batchArray = Unirest.post("http://localhost:9060/transformarray")
.header("accept", "application/json").header("Content-Type", "application/json").body(batch)
.asObject(Base64NDArrayBody.class).getBody();
INDArray result = getNDArray(jsonNode);
assertEquals(1, result.size(0));
INDArray batchResult = getNDArray(jsonNodeBatch);
assertEquals(3, batchResult.size(0));
// System.out.println(array);
}
@Test
public void testImageServerMultipart() throws Exception {
JsonNode jsonNode = Unirest.post("http://localhost:9060/transformimage")
.header("accept", "application/json")
.field("file1", new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile())
.field("file2", new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile())
.field("file3", new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile())
.asJson().getBody();
INDArray batchResult = getNDArray(jsonNode);
assertEquals(3, batchResult.size(0));
// System.out.println(batchResult);
}
@Test
public void testImageServerSingleMultipart() throws Exception {
File f = testDir.newFolder();
File imgFile = new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getTempFileFromArchive(f);
JsonNode jsonNode = Unirest.post("http://localhost:9060/transformimage")
.header("accept", "application/json")
.field("file1", imgFile)
.asJson().getBody();
INDArray result = getNDArray(jsonNode);
assertEquals(1, result.size(0));
// System.out.println(result);
}
public INDArray getNDArray(JsonNode node) throws IOException {
return Nd4jBase64.fromBase64(node.getObject().getString("ndarray"));
}
}

@ -1,168 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.datavec.spark.transform;
import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.server.SparkTransformServerChooser;
import org.datavec.spark.inference.server.TransformDataType;
import org.datavec.spark.inference.model.model.*;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.serde.base64.Nd4jBase64;
import java.io.File;
import java.io.IOException;
import java.util.UUID;
import static org.junit.Assert.assertEquals;
public class SparkTransformServerTest {
private static SparkTransformServerChooser serverChooser;
private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
private static TransformProcess transformProcess =
            new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build();
private static File imageTransformFile = new File(UUID.randomUUID().toString() + ".json");
private static File csvTransformFile = new File(UUID.randomUUID().toString() + ".json");
@BeforeClass
public static void before() throws Exception {
serverChooser = new SparkTransformServerChooser();
ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(12345)
.scaleImageTransform(10).cropImageTransform(5).build();
FileUtils.write(imageTransformFile, imgTransformProcess.toJson());
FileUtils.write(csvTransformFile, transformProcess.toJson());
Unirest.setObjectMapper(new ObjectMapper() {
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
new org.nd4j.shade.jackson.databind.ObjectMapper();
public <T> T readValue(String value, Class<T> valueType) {
try {
return jacksonObjectMapper.readValue(value, valueType);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public String writeValue(Object value) {
try {
return jacksonObjectMapper.writeValueAsString(value);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
});
}
@AfterClass
public static void after() throws Exception {
imageTransformFile.deleteOnExit();
csvTransformFile.deleteOnExit();
}
@Test
public void testImageServer() throws Exception {
serverChooser.runMain(new String[] {"--jsonPath", imageTransformFile.getAbsolutePath(), "-dp", "9060", "-dt",
TransformDataType.IMAGE.toString()});
SingleImageRecord record =
new SingleImageRecord(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI());
JsonNode jsonNode = Unirest.post("http://localhost:9060/transformincrementalarray")
.header("accept", "application/json").header("Content-Type", "application/json").body(record)
.asJson().getBody();
Base64NDArrayBody array = Unirest.post("http://localhost:9060/transformincrementalarray")
.header("accept", "application/json").header("Content-Type", "application/json").body(record)
.asObject(Base64NDArrayBody.class).getBody();
BatchImageRecord batch = new BatchImageRecord();
batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI());
batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile().toURI());
batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile().toURI());
JsonNode jsonNodeBatch =
Unirest.post("http://localhost:9060/transformarray").header("accept", "application/json")
.header("Content-Type", "application/json").body(batch).asJson().getBody();
Base64NDArrayBody batchArray = Unirest.post("http://localhost:9060/transformarray")
.header("accept", "application/json").header("Content-Type", "application/json").body(batch)
.asObject(Base64NDArrayBody.class).getBody();
INDArray result = getNDArray(jsonNode);
assertEquals(1, result.size(0));
INDArray batchResult = getNDArray(jsonNodeBatch);
assertEquals(3, batchResult.size(0));
serverChooser.getSparkTransformServer().stop();
}
@Test
public void testCSVServer() throws Exception {
serverChooser.runMain(new String[] {"--jsonPath", csvTransformFile.getAbsolutePath(), "-dp", "9050", "-dt",
TransformDataType.CSV.toString()});
String[] values = new String[] {"1.0", "2.0"};
SingleCSVRecord record = new SingleCSVRecord(values);
JsonNode jsonNode =
Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json")
.header("Content-Type", "application/json").body(record).asJson().getBody();
SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental")
.header("accept", "application/json").header("Content-Type", "application/json").body(record)
.asObject(SingleCSVRecord.class).getBody();
BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
for (int i = 0; i < 3; i++)
batchCSVRecord.add(singleCsvRecord);
BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform")
.header("accept", "application/json").header("Content-Type", "application/json")
.body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody();
Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray")
.header("accept", "application/json").header("Content-Type", "application/json").body(record)
.asObject(Base64NDArrayBody.class).getBody();
Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray")
.header("accept", "application/json").header("Content-Type", "application/json")
.body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody();
serverChooser.getSparkTransformServer().stop();
}
public INDArray getNDArray(JsonNode node) throws IOException {
return Nd4jBase64.fromBase64(node.getObject().getString("ndarray"));
}
}

@ -1,6 +0,0 @@
play.modules.enabled += com.lightbend.lagom.discovery.zookeeper.ZooKeeperServiceLocatorModule
play.modules.enabled += io.skymind.skil.service.PredictionModule
play.crypto.secret = as8dufasdfuasdfjkasdkfalksjfk
play.server.pidfile.path=/tmp/RUNNING_PID
play.server.http.port = 9600

@ -1,68 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.datavec</groupId>
<artifactId>datavec-parent</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>datavec-spark-inference-parent</artifactId>
<packaging>pom</packaging>
<name>datavec-spark-inference-parent</name>
<modules>
<module>datavec-spark-inference-server</module>
<module>datavec-spark-inference-client</module>
<module>datavec-spark-inference-model</module>
</modules>
<dependencyManagement>
<dependencies>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-data-image</artifactId>
<version>${datavec.version}</version>
</dependency>
<dependency>
<groupId>com.mashape.unirest</groupId>
<artifactId>unirest-java</artifactId>
<version>${unirest.version}</version>
</dependency>
</dependencies>
</dependencyManagement>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
</profile>
</profiles>
</project>

@ -45,7 +45,6 @@
<module>datavec-data</module>
<module>datavec-spark</module>
<module>datavec-local</module>
<module>datavec-spark-inference-parent</module>
<module>datavec-jdbc</module>
<module>datavec-excel</module>
<module>datavec-arrow</module>

@ -1,143 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nearestneighbors-parent</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>deeplearning4j-nearestneighbor-server</artifactId>
<packaging>jar</packaging>
<name>deeplearning4j-nearestneighbor-server</name>
<properties>
<java.compile.version>1.8</java.compile.version>
</properties>
<dependencies>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nearestneighbors-model</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>io.vertx</groupId>
<artifactId>vertx-core</artifactId>
<version>${vertx.version}</version>
</dependency>
<dependency>
<groupId>io.vertx</groupId>
<artifactId>vertx-web</artifactId>
<version>${vertx.version}</version>
</dependency>
<dependency>
<groupId>com.mashape.unirest</groupId>
<artifactId>unirest-java</artifactId>
<version>${unirest.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nearestneighbors-client</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.beust</groupId>
<artifactId>jcommander</artifactId>
<version>${jcommander.version}</version>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-common-tests</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<argLine>-Dfile.encoding=UTF-8 -Xmx8g</argLine>
<includes>
<!-- Default setting only runs tests that start/end with "Test" -->
<include>*.java</include>
<include>**/*.java</include>
</includes>
</configuration>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>${java.compile.version}</source>
<target>${java.compile.version}</target>
</configuration>
</plugin>
</plugins>
</build>
<profiles>
<profile>
<id>test-nd4j-native</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-11.0</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
</profiles>
</project>

@ -1,67 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nearestneighbor.server;
import lombok.AllArgsConstructor;
import lombok.Builder;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborRequest;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.util.ArrayList;
import java.util.List;
@AllArgsConstructor
@Builder
public class NearestNeighbor {
private NearestNeighborRequest record;
private VPTree tree;
private INDArray points;
public List<NearestNeighborsResult> search() {
INDArray input = points.slice(record.getInputIndex());
List<NearestNeighborsResult> results = new ArrayList<>();
if (input.isVector()) {
List<DataPoint> add = new ArrayList<>();
List<Double> distances = new ArrayList<>();
tree.search(input, record.getK(), add, distances);
if (add.size() != distances.size()) {
throw new IllegalStateException(
String.format("add.size == %d != %d == distances.size",
add.size(), distances.size()));
}
for (int i=0; i<add.size(); i++) {
results.add(new NearestNeighborsResult(add.get(i).getIndex(), distances.get(i)));
}
}
return results;
}
}

@ -1,278 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nearestneighbor.server;
import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.vertx.core.AbstractVerticle;
import io.vertx.core.Vertx;
import io.vertx.ext.web.Router;
import io.vertx.ext.web.handler.BodyHandler;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.clustering.vptree.VPTreeFillSearch;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nearestneighbor.model.*;
import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.serde.base64.Nd4jBase64;
import org.nd4j.serde.binary.BinarySerde;
import java.io.File;
import java.util.*;
@Slf4j
public class NearestNeighborsServer extends AbstractVerticle {
private static class RunArgs {
@Parameter(names = {"--ndarrayPath"}, arity = 1, required = true)
private String ndarrayPath = null;
@Parameter(names = {"--labelsPath"}, arity = 1, required = false)
private String labelsPath = null;
@Parameter(names = {"--nearestNeighborsPort"}, arity = 1)
private int port = 9000;
@Parameter(names = {"--similarityFunction"}, arity = 1)
private String similarityFunction = "euclidean";
@Parameter(names = {"--invert"}, arity = 1)
private boolean invert = false;
}
private static RunArgs instanceArgs;
private static NearestNeighborsServer instance;
public NearestNeighborsServer(){ }
public static NearestNeighborsServer getInstance(){
return instance;
}
public static void runMain(String... args) {
RunArgs r = new RunArgs();
JCommander jcmdr = new JCommander(r);
try {
jcmdr.parse(args);
} catch (ParameterException e) {
log.error("Error in NearestNeighboursServer parameters", e);
StringBuilder sb = new StringBuilder();
jcmdr.usage(sb);
log.error("Usage: {}", sb.toString());
//User provides invalid input -> print the usage info
jcmdr.usage();
if (r.ndarrayPath == null)
log.error("Json path parameter is missing (null)");
try {
Thread.sleep(500);
} catch (Exception e2) {
}
System.exit(1);
}
instanceArgs = r;
try {
Vertx vertx = Vertx.vertx();
vertx.deployVerticle(NearestNeighborsServer.class.getName());
} catch (Throwable t){
log.error("Error in NearestNeighboursServer run method",t);
}
}
@Override
public void start() throws Exception {
instance = this;
String[] pathArr = instanceArgs.ndarrayPath.split(",");
//INDArray[] pointsArr = new INDArray[pathArr.length];
        // First of all, read the shapes of the previously saved files
int rows = 0;
int cols = 0;
for (int i = 0; i < pathArr.length; i++) {
DataBuffer shape = BinarySerde.readShapeFromDisk(new File(pathArr[i]));
log.info("Loading shape {} of {}; Shape: [{} x {}]", i + 1, pathArr.length, Shape.size(shape, 0),
Shape.size(shape, 1));
if (Shape.rank(shape) != 2)
throw new DL4JInvalidInputException("NearestNeighborsServer assumes 2D chunks");
rows += Shape.size(shape, 0);
if (cols == 0)
cols = Shape.size(shape, 1);
else if (cols != Shape.size(shape, 1))
throw new DL4JInvalidInputException(
"NearestNeighborsServer requires equal 2D chunks. Got columns mismatch.");
}
final List<String> labels = new ArrayList<>();
if (instanceArgs.labelsPath != null) {
String[] labelsPathArr = instanceArgs.labelsPath.split(",");
for (int i = 0; i < labelsPathArr.length; i++) {
labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8"));
}
}
if (!labels.isEmpty() && labels.size() != rows)
throw new DL4JInvalidInputException(String.format("Number of labels must match number of rows in points matrix (expected %d, found %d)", rows, labels.size()));
final INDArray points = Nd4j.createUninitialized(rows, cols);
int lastPosition = 0;
for (int i = 0; i < pathArr.length; i++) {
log.info("Loading chunk {} of {}", i + 1, pathArr.length);
INDArray pointsArr = BinarySerde.readFromDisk(new File(pathArr[i]));
points.get(NDArrayIndex.interval(lastPosition, lastPosition + pointsArr.rows())).assign(pointsArr);
lastPosition += pointsArr.rows();
            // hint the GC so memory from this chunk is reclaimed before the next one is loaded
System.gc();
}
VPTree tree = new VPTree(points, instanceArgs.similarityFunction, instanceArgs.invert);
//Set play secret key, if required
//http://www.playframework.com/documentation/latest/ApplicationSecret
String crypto = System.getProperty("play.crypto.secret");
if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) {
byte[] newCrypto = new byte[1024];
new Random().nextBytes(newCrypto);
String base64 = Base64.getEncoder().encodeToString(newCrypto);
System.setProperty("play.crypto.secret", base64);
}
Router r = Router.router(vertx);
r.route().handler(BodyHandler.create()); //NOTE: Setting this is required to receive request body content at all
createRoutes(r, labels, tree, points);
vertx.createHttpServer()
.requestHandler(r)
.listen(instanceArgs.port);
}
private void createRoutes(Router r, List<String> labels, VPTree tree, INDArray points){
r.post("/knn").handler(rc -> {
try {
String json = rc.getBodyAsJson().encode();
NearestNeighborRequest record = JsonMappers.getMapper().readValue(json, NearestNeighborRequest.class);
                // validate the request before building the searcher
                if (record == null) {
                    rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
                            .putHeader("content-type", "application/json")
                            .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
                    return;
                }
                NearestNeighbor nearestNeighbor =
                        NearestNeighbor.builder().points(points).record(record).tree(tree).build();
                NearestNeighborsResults results = NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
                rc.response().setStatusCode(HttpResponseStatus.OK.code())
                        .putHeader("content-type", "application/json")
                        .end(JsonMappers.getMapper().writeValueAsString(results));
                return;
} catch (Throwable e) {
log.error("Error in POST /knn",e);
rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
.end("Error parsing request - " + e.getMessage());
return;
}
});
r.post("/knnnew").handler(rc -> {
try {
String json = rc.getBodyAsJson().encode();
Base64NDArrayBody record = JsonMappers.getMapper().readValue(json, Base64NDArrayBody.class);
if (record == null) {
rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
.putHeader("content-type", "application/json")
.end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
return;
}
INDArray arr = Nd4jBase64.fromBase64(record.getNdarray());
List<DataPoint> results;
List<Double> distances;
if (record.isForceFillK()) {
VPTreeFillSearch vpTreeFillSearch = new VPTreeFillSearch(tree, record.getK(), arr);
vpTreeFillSearch.search();
results = vpTreeFillSearch.getResults();
distances = vpTreeFillSearch.getDistances();
} else {
results = new ArrayList<>();
distances = new ArrayList<>();
tree.search(arr, record.getK(), results, distances);
}
if (results.size() != distances.size()) {
rc.response()
.setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
.end(String.format("results.size == %d != %d == distances.size", results.size(), distances.size()));
return;
}
List<NearestNeighborsResult> nnResult = new ArrayList<>();
for (int i=0; i<results.size(); i++) {
if (!labels.isEmpty())
nnResult.add(new NearestNeighborsResult(results.get(i).getIndex(), distances.get(i), labels.get(results.get(i).getIndex())));
else
nnResult.add(new NearestNeighborsResult(results.get(i).getIndex(), distances.get(i)));
}
NearestNeighborsResults results2 = NearestNeighborsResults.builder().results(nnResult).build();
String j = JsonMappers.getMapper().writeValueAsString(results2);
rc.response()
.putHeader("content-type", "application/json")
.end(j);
} catch (Throwable e) {
log.error("Error in POST /knnnew",e);
rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
.end("Error parsing request - " + e.getMessage());
return;
}
});
}
/**
* Stop the server
*/
public void stop() throws Exception {
super.stop();
}
public static void main(String[] args) throws Exception {
runMain(args);
}
}
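/*
 * Hypothetical launch sketch (not part of the removed sources): shows how the CLI
 * arguments declared in RunArgs above might be passed to runMain. The file name
 * "points.bin" is an assumption; it stands for a 2D matrix previously serialized
 * with BinarySerde.writeArrayToDisk.
 */
class NearestNeighborsServerLaunchExample {
    public static void main(String[] args) {
        org.deeplearning4j.nearestneighbor.server.NearestNeighborsServer.runMain(
                "--ndarrayPath", "points.bin",        // comma-separated list of 2D chunks on disk
                "--nearestNeighborsPort", "9000",     // HTTP port serving /knn and /knnnew
                "--similarityFunction", "euclidean",  // distance function used by the VPTree
                "--invert", "false");                 // whether to invert the similarity metric
    }
}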

@ -1,161 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nearestneighbor.server;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.clustering.vptree.VPTreeFillSearch;
import org.deeplearning4j.nearestneighbor.client.NearestNeighborsClient;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborRequest;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResults;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.serde.binary.BinarySerde;
import java.io.File;
import java.io.IOException;
import java.net.ServerSocket;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import static org.junit.Assert.assertEquals;
public class NearestNeighborTest extends BaseDL4JTest {
@Rule
public TemporaryFolder testDir = new TemporaryFolder();
@Test
public void testNearestNeighbor() {
double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}};
INDArray arr = Nd4j.create(data);
VPTree vpTree = new VPTree(arr, false);
NearestNeighborRequest request = new NearestNeighborRequest();
request.setK(2);
request.setInputIndex(0);
NearestNeighbor nearestNeighbor = NearestNeighbor.builder().tree(vpTree).points(arr).record(request).build();
List<NearestNeighborsResult> results = nearestNeighbor.search();
assertEquals(1, results.get(0).getIndex());
assertEquals(2, results.size());
assertEquals(1.0, results.get(0).getDistance(), 1e-4);
assertEquals(4.0, results.get(1).getDistance(), 1e-4);
}
@Test
public void testNearestNeighborInverted() {
double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}};
INDArray arr = Nd4j.create(data);
VPTree vpTree = new VPTree(arr, true);
NearestNeighborRequest request = new NearestNeighborRequest();
request.setK(2);
request.setInputIndex(0);
NearestNeighbor nearestNeighbor = NearestNeighbor.builder().tree(vpTree).points(arr).record(request).build();
List<NearestNeighborsResult> results = nearestNeighbor.search();
assertEquals(2, results.get(0).getIndex());
assertEquals(2, results.size());
assertEquals(-4.0, results.get(0).getDistance(), 1e-4);
assertEquals(-1.0, results.get(1).getDistance(), 1e-4);
}
@Test
public void vpTreeTest() throws Exception {
INDArray matrix = Nd4j.rand(new int[] {400,10});
INDArray rowVector = matrix.getRow(70);
INDArray resultArr = Nd4j.zeros(400,1);
Executor executor = Executors.newSingleThreadExecutor();
VPTree vpTree = new VPTree(matrix);
System.out.println("Ran!");
}
public static int getAvailablePort() {
try {
ServerSocket socket = new ServerSocket(0);
try {
return socket.getLocalPort();
} finally {
socket.close();
}
} catch (IOException e) {
throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e);
}
}
@Test
public void testServer() throws Exception {
int localPort = getAvailablePort();
Nd4j.getRandom().setSeed(7);
INDArray rand = Nd4j.randn(10, 5);
File writeToTmp = testDir.newFile();
writeToTmp.deleteOnExit();
BinarySerde.writeArrayToDisk(rand, writeToTmp);
NearestNeighborsServer.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort",
String.valueOf(localPort));
Thread.sleep(3000);
NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort);
NearestNeighborsResults result = client.knnNew(5, rand.getRow(0));
assertEquals(5, result.getResults().size());
NearestNeighborsServer.getInstance().stop();
}
@Test
public void testFullSearch() throws Exception {
int numRows = 1000;
int numCols = 100;
int numNeighbors = 42;
INDArray points = Nd4j.rand(numRows, numCols);
VPTree tree = new VPTree(points);
INDArray query = Nd4j.rand(new int[] {1, numCols});
VPTreeFillSearch fillSearch = new VPTreeFillSearch(tree, numNeighbors, query);
fillSearch.search();
List<DataPoint> results = fillSearch.getResults();
List<Double> distances = fillSearch.getDistances();
assertEquals(numNeighbors, distances.size());
assertEquals(numNeighbors, results.size());
}
@Test
public void testDistances() {
INDArray indArray = Nd4j.create(new float[][]{{3, 4}, {1, 2}, {5, 6}});
INDArray record = Nd4j.create(new float[][]{{7, 6}});
VPTree vpTree = new VPTree(indArray, "euclidean", false);
VPTreeFillSearch vpTreeFillSearch = new VPTreeFillSearch(vpTree, 3, record);
vpTreeFillSearch.search();
//System.out.println(vpTreeFillSearch.getResults());
System.out.println(vpTreeFillSearch.getDistances());
}
}

@ -1,46 +0,0 @@
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->
<configuration>
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
<file>logs/application.log</file>
<encoder>
<pattern> %logger{15} - %message%n%xException{5}
</pattern>
</encoder>
</appender>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern> %logger{15} - %message%n%xException{5}
</pattern>
</encoder>
</appender>
<logger name="org.deeplearning4j" level="INFO" />
<logger name="org.datavec" level="INFO" />
<logger name="org.nd4j" level="INFO" />
<root level="ERROR">
<appender-ref ref="STDOUT" />
<appender-ref ref="FILE" />
</root>
</configuration>

@ -1,60 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nearestneighbors-parent</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>deeplearning4j-nearestneighbors-client</artifactId>
<packaging>jar</packaging>
<name>deeplearning4j-nearestneighbors-client</name>
<dependencies>
<dependency>
<groupId>com.mashape.unirest</groupId>
<artifactId>unirest-java</artifactId>
<version>${unirest.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nearestneighbors-model</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
</profile>
</profiles>
</project>

@ -1,137 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nearestneighbor.client;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import com.mashape.unirest.request.HttpRequest;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import lombok.val;
import org.deeplearning4j.nearestneighbor.model.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;
import org.nd4j.shade.jackson.core.JsonProcessingException;
import java.io.IOException;
@AllArgsConstructor
public class NearestNeighborsClient {
private String url;
@Setter
@Getter
protected String authToken;
public NearestNeighborsClient(String url){
this(url, null);
}
static {
// Only one time
Unirest.setObjectMapper(new ObjectMapper() {
private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
new org.nd4j.shade.jackson.databind.ObjectMapper();
public <T> T readValue(String value, Class<T> valueType) {
try {
return jacksonObjectMapper.readValue(value, valueType);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
public String writeValue(Object value) {
try {
return jacksonObjectMapper.writeValueAsString(value);
} catch (JsonProcessingException e) {
throw new RuntimeException(e);
}
}
});
}
    /**
     * Runs a k nearest neighbors search on the given index
     * with the given k. Note that this is for a data point
     * already within the existing dataset, not new data.
     * @param index the index of the EXISTING ndarray row to run the search on
     * @param k the number of results to return
     * @return the nearest neighbors results for the given index
     * @throws Exception if the request to the server fails
     */
public NearestNeighborsResults knn(int index, int k) throws Exception {
NearestNeighborRequest request = new NearestNeighborRequest();
request.setInputIndex(index);
request.setK(k);
val req = Unirest.post(url + "/knn");
req.header("accept", "application/json")
.header("Content-Type", "application/json").body(request);
addAuthHeader(req);
NearestNeighborsResults ret = req.asObject(NearestNeighborsResults.class).getBody();
return ret;
}
    /**
     * Run a k nearest neighbors search
     * on a NEW data point
     * @param k the number of results to retrieve
     * @param arr the array to run the search on; note that this must be a row vector
     * @return the nearest neighbors results for the given array
     * @throws Exception if the request to the server fails
     */
public NearestNeighborsResults knnNew(int k, INDArray arr) throws Exception {
Base64NDArrayBody base64NDArrayBody =
Base64NDArrayBody.builder().k(k).ndarray(Nd4jBase64.base64String(arr)).build();
val req = Unirest.post(url + "/knnnew");
req.header("accept", "application/json")
.header("Content-Type", "application/json").body(base64NDArrayBody);
addAuthHeader(req);
NearestNeighborsResults ret = req.asObject(NearestNeighborsResults.class).getBody();
return ret;
}
/**
* Add the specified authentication header to the specified HttpRequest
*
* @param request HTTP Request to add the authentication header to
*/
protected HttpRequest addAuthHeader(HttpRequest request) {
if (authToken != null) {
request.header("authorization", "Bearer " + authToken);
}
return request;
}
}
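/*
 * Hypothetical usage sketch (not part of the removed sources): queries a running
 * NearestNeighborsServer. The URL and the shape of the query vector are assumptions.
 */
class NearestNeighborsClientExample {
    public static void main(String[] args) throws Exception {
        NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:9000");
        // client.setAuthToken("my-token");  // optional: adds an Authorization: Bearer header
        // neighbors of a row that already exists in the indexed matrix (index 0, k = 5)
        NearestNeighborsResults existing = client.knn(0, 5);
        // neighbors of a brand new row vector (k = 5)
        INDArray query = org.nd4j.linalg.factory.Nd4j.rand(1, 100);
        NearestNeighborsResults fresh = client.knnNew(5, query);
        System.out.println(existing.getResults().size() + " / " + fresh.getResults().size());
    }
}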

@ -1,61 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nearestneighbors-parent</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>deeplearning4j-nearestneighbors-model</artifactId>
<packaging>jar</packaging>
<name>deeplearning4j-nearestneighbors-model</name>
<dependencies>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>${lombok.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-api</artifactId>
<version>${nd4j.version}</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
</profile>
</profiles>
</project>

@ -1,38 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nearestneighbor.model;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.io.Serializable;
@Data
@AllArgsConstructor
@NoArgsConstructor
@Builder
public class Base64NDArrayBody implements Serializable {
private String ndarray;
private int k;
private boolean forceFillK;
}

@ -1,65 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nearestneighbor.model;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.dataset.DataSet;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
@Data
@AllArgsConstructor
@Builder
@NoArgsConstructor
public class BatchRecord implements Serializable {
private List<CSVRecord> records;
    /**
     * Add a record
     * @param record the record to add
     */
public void add(CSVRecord record) {
if (records == null)
records = new ArrayList<>();
records.add(record);
}
/**
* Return a batch record based on a dataset
* @param dataSet the dataset to get the batch record for
* @return the batch record
*/
public static BatchRecord fromDataSet(DataSet dataSet) {
BatchRecord batchRecord = new BatchRecord();
for (int i = 0; i < dataSet.numExamples(); i++) {
batchRecord.add(CSVRecord.fromRow(dataSet.get(i)));
}
return batchRecord;
}
}
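/*
 * Hypothetical usage sketch (not part of the removed sources): accumulates CSV
 * records one at a time; add() lazily creates the backing list on first use.
 */
class BatchRecordExample {
    public static void main(String[] args) {
        BatchRecord batch = new BatchRecord();
        batch.add(new CSVRecord(new String[] {"1.0", "2.0"}));
        batch.add(new CSVRecord(new String[] {"3.0", "4.0"}));
        System.out.println(batch.getRecords().size()); // prints 2
    }
}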

@ -1,85 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nearestneighbor.model;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.dataset.DataSet;
import java.io.Serializable;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class CSVRecord implements Serializable {
private String[] values;
    /**
     * Instantiate a CSV record from a {@link DataSet} row.
     * For classification (a one-hot label vector), the index of the
     * maximum label is appended to the end of the record; for regression,
     * all label values are appended.
     * @param row the input row (features and labels)
     * @return the record built from this {@link DataSet}
     */
public static CSVRecord fromRow(DataSet row) {
if (!row.getFeatures().isVector() && !row.getFeatures().isScalar())
throw new IllegalArgumentException("Passed in dataset must represent a scalar or vector");
if (!row.getLabels().isVector() && !row.getLabels().isScalar())
throw new IllegalArgumentException("Passed in dataset labels must be a scalar or vector");
//classification
CSVRecord record;
int idx = 0;
if (row.getLabels().sumNumber().doubleValue() == 1.0) {
String[] values = new String[row.getFeatures().columns() + 1];
for (int i = 0; i < row.getFeatures().length(); i++) {
values[idx++] = String.valueOf(row.getFeatures().getDouble(i));
}
int maxIdx = 0;
for (int i = 0; i < row.getLabels().length(); i++) {
if (row.getLabels().getDouble(maxIdx) < row.getLabels().getDouble(i)) {
maxIdx = i;
}
}
values[idx++] = String.valueOf(maxIdx);
record = new CSVRecord(values);
}
//regression (any number of values)
else {
String[] values = new String[row.getFeatures().columns() + row.getLabels().columns()];
for (int i = 0; i < row.getFeatures().length(); i++) {
values[idx++] = String.valueOf(row.getFeatures().getDouble(i));
}
for (int i = 0; i < row.getLabels().length(); i++) {
values[idx++] = String.valueOf(row.getLabels().getDouble(i));
}
record = new CSVRecord(values);
}
return record;
}
}
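/*
 * Hypothetical usage sketch (not part of the removed sources): converts a one-example
 * classification DataSet (two features, one-hot label for class 1) into a CSV record.
 */
class CSVRecordExample {
    public static void main(String[] args) {
        DataSet row = new DataSet(
                org.nd4j.linalg.factory.Nd4j.create(new double[][] {{0.5, 1.5}}),   // features
                org.nd4j.linalg.factory.Nd4j.create(new double[][] {{0.0, 1.0}}));  // one-hot labels
        CSVRecord record = CSVRecord.fromRow(row);
        // prints [0.5, 1.5, 1] — features first, then the index of the winning label
        System.out.println(java.util.Arrays.toString(record.getValues()));
    }
}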

@ -1,32 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nearestneighbor.model;
import lombok.Data;
import java.io.Serializable;
@Data
public class NearestNeighborRequest implements Serializable {
private int k;
private int inputIndex;
}

@ -1,37 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nearestneighbor.model;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
@Data
@AllArgsConstructor
@NoArgsConstructor
public class NearestNeighborsResult {
public NearestNeighborsResult(int index, double distance) {
this(index, distance, null);
}
private int index;
private double distance;
private String label;
}

@ -1,38 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.nearestneighbor.model;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.io.Serializable;
import java.util.List;
@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class NearestNeighborsResults implements Serializable {
private List<NearestNeighborsResult> results;
}
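
For context on how these removed DTOs fit together, here is a minimal, illustrative sketch that builds a response from raw index/distance pairs using the Lombok-generated constructors and builder shown above (the values are arbitrary examples):

// Illustrative sketch only: assembling a nearest-neighbour response.
import java.util.Arrays;
import java.util.List;

import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResults;

public class NearestNeighborsResultsExample {
    public static void main(String[] args) {
        List<NearestNeighborsResult> hits = Arrays.asList(
                new NearestNeighborsResult(3, 0.12),          // index + distance, label left null
                new NearestNeighborsResult(7, 0.34, "cat"));  // index + distance + label

        NearestNeighborsResults response = NearestNeighborsResults.builder()
                .results(hits)
                .build();

        System.out.println(response.getResults().size()); // 2
    }
}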

View File

@ -1,103 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nearestneighbors-parent</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>nearestneighbor-core</artifactId>
<packaging>jar</packaging>
<name>nearestneighbor-core</name>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-api</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nn</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-datasets</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>joda-time</groupId>
<artifactId>joda-time</artifactId>
<version>2.10.3</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-common-tests</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-native</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
<profile>
<id>test-nd4j-cuda-11.0</id>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-cuda-11.0</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
</profile>
</profiles>
</project>

View File

@ -1,218 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.algorithm;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.clustering.cluster.Cluster;
import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.ClusterUtils;
import org.deeplearning4j.clustering.cluster.Point;
import org.deeplearning4j.clustering.info.ClusterSetInfo;
import org.deeplearning4j.clustering.iteration.IterationHistory;
import org.deeplearning4j.clustering.iteration.IterationInfo;
import org.deeplearning4j.clustering.strategy.ClusteringStrategy;
import org.deeplearning4j.clustering.strategy.ClusteringStrategyType;
import org.deeplearning4j.clustering.strategy.OptimisationStrategy;
import org.deeplearning4j.clustering.util.MultiThreadUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;
@Slf4j
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializable {
private static final long serialVersionUID = 338231277453149972L;
private ClusteringStrategy clusteringStrategy;
private IterationHistory iterationHistory;
private int currentIteration = 0;
private ClusterSet clusterSet;
private List<Point> initialPoints;
private transient ExecutorService exec;
private boolean useKmeansPlusPlus;
protected BaseClusteringAlgorithm(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) {
this.clusteringStrategy = clusteringStrategy;
this.exec = MultiThreadUtils.newExecutorService();
this.useKmeansPlusPlus = useKmeansPlusPlus;
}
    /**
     * Set up a new clustering algorithm instance.
     *
     * @param clusteringStrategy the strategy controlling initialization, termination and optimisation
     * @param useKmeansPlusPlus  whether to use k-means++ style selection of the initial cluster centers
     * @return a new BaseClusteringAlgorithm
     */
public static BaseClusteringAlgorithm setup(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) {
return new BaseClusteringAlgorithm(clusteringStrategy, useKmeansPlusPlus);
}
    /**
     * Run the clustering algorithm over the given points.
     *
     * @param points the points to cluster
     * @return the resulting cluster set
     */
public ClusterSet applyTo(List<Point> points) {
resetState(points);
initClusters(useKmeansPlusPlus);
iterations();
return clusterSet;
}
private void resetState(List<Point> points) {
this.iterationHistory = new IterationHistory();
this.currentIteration = 0;
this.clusterSet = null;
this.initialPoints = points;
}
/** Run clustering iterations until a
* termination condition is hit.
* This is done by first classifying all points,
* and then updating cluster centers based on
* those classified points
*/
private void iterations() {
int iterationCount = 0;
while ((clusteringStrategy.getTerminationCondition() != null
&& !clusteringStrategy.getTerminationCondition().isSatisfied(iterationHistory))
|| iterationHistory.getMostRecentIterationInfo().isStrategyApplied()) {
currentIteration++;
removePoints();
classifyPoints();
applyClusteringStrategy();
log.trace("Completed clustering iteration {}", ++iterationCount);
}
}
protected void classifyPoints() {
//Classify points. This also adds each point to the ClusterSet
ClusterSetInfo clusterSetInfo = ClusterUtils.classifyPoints(clusterSet, initialPoints, exec);
//Update the cluster centers, based on the points within each cluster
ClusterUtils.refreshClustersCenters(clusterSet, clusterSetInfo, exec);
iterationHistory.getIterationsInfos().put(currentIteration,
new IterationInfo(currentIteration, clusterSetInfo));
}
/**
* Initialize the
* cluster centers at random
*/
protected void initClusters(boolean kMeansPlusPlus) {
log.info("Generating initial clusters");
List<Point> points = new ArrayList<>(initialPoints);
//Initialize the ClusterSet with a single cluster center (based on position of one of the points chosen randomly)
val random = Nd4j.getRandom();
Distance distanceFn = clusteringStrategy.getDistanceFunction();
int initialClusterCount = clusteringStrategy.getInitialClusterCount();
clusterSet = new ClusterSet(distanceFn,
clusteringStrategy.inverseDistanceCalculation(), new long[]{initialClusterCount, points.get(0).getArray().length()});
clusterSet.addNewClusterWithCenter(points.remove(random.nextInt(points.size())));
//dxs: distances between
// each point and nearest cluster to that point
INDArray dxs = Nd4j.create(points.size());
dxs.addi(clusteringStrategy.inverseDistanceCalculation() ? -Double.MAX_VALUE : Double.MAX_VALUE);
//Generate the initial cluster centers, by randomly selecting a point between 0 and max distance
//Thus, we are more likely to select (as a new cluster center) a point that is far from an existing cluster
while (clusterSet.getClusterCount() < initialClusterCount && !points.isEmpty()) {
dxs = ClusterUtils.computeSquareDistancesFromNearestCluster(clusterSet, points, dxs, exec);
double summed = Nd4j.sum(dxs).getDouble(0);
double r = kMeansPlusPlus ? random.nextDouble() * summed:
random.nextFloat() * dxs.maxNumber().doubleValue();
for (int i = 0; i < dxs.length(); i++) {
double distance = dxs.getDouble(i);
Preconditions.checkState(distance >= 0, "Encountered negative distance: distance function is not valid? Distance " +
"function must return values >= 0, got distance %s for function s", distance, distanceFn);
if (dxs.getDouble(i) >= r) {
clusterSet.addNewClusterWithCenter(points.remove(i));
dxs = Nd4j.create(ArrayUtils.remove(dxs.data().asDouble(), i));
break;
}
}
}
ClusterSetInfo initialClusterSetInfo = ClusterUtils.computeClusterSetInfo(clusterSet);
iterationHistory.getIterationsInfos().put(currentIteration,
new IterationInfo(currentIteration, initialClusterSetInfo));
}
protected void applyClusteringStrategy() {
if (!isStrategyApplicableNow())
return;
ClusterSetInfo clusterSetInfo = iterationHistory.getMostRecentClusterSetInfo();
if (!clusteringStrategy.isAllowEmptyClusters()) {
int removedCount = removeEmptyClusters(clusterSetInfo);
if (removedCount > 0) {
iterationHistory.getMostRecentIterationInfo().setStrategyApplied(true);
if (clusteringStrategy.isStrategyOfType(ClusteringStrategyType.FIXED_CLUSTER_COUNT)
&& clusterSet.getClusterCount() < clusteringStrategy.getInitialClusterCount()) {
int splitCount = ClusterUtils.splitMostSpreadOutClusters(clusterSet, clusterSetInfo,
clusteringStrategy.getInitialClusterCount() - clusterSet.getClusterCount(), exec);
if (splitCount > 0)
iterationHistory.getMostRecentIterationInfo().setStrategyApplied(true);
}
}
}
if (clusteringStrategy.isStrategyOfType(ClusteringStrategyType.OPTIMIZATION))
optimize();
}
protected void optimize() {
ClusterSetInfo clusterSetInfo = iterationHistory.getMostRecentClusterSetInfo();
OptimisationStrategy optimization = (OptimisationStrategy) clusteringStrategy;
boolean applied = ClusterUtils.applyOptimization(optimization, clusterSet, clusterSetInfo, exec);
iterationHistory.getMostRecentIterationInfo().setStrategyApplied(applied);
}
private boolean isStrategyApplicableNow() {
return clusteringStrategy.isOptimizationDefined() && iterationHistory.getIterationCount() != 0
&& clusteringStrategy.isOptimizationApplicableNow(iterationHistory);
}
protected int removeEmptyClusters(ClusterSetInfo clusterSetInfo) {
List<Cluster> removedClusters = clusterSet.removeEmptyClusters();
clusterSetInfo.removeClusterInfos(removedClusters);
return removedClusters.size();
}
protected void removePoints() {
clusterSet.removePoints();
}
}
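
For orientation, a hedged usage sketch of the removed algorithm. It relies only on members visible in this diff (setup, applyTo, Point.toPoints and the ClusterSet/Cluster accessors); the concrete ClusteringStrategy is assumed to be supplied by the rest of the clustering module, so it is passed in rather than constructed here:

// Sketch only: clusters a random data matrix with the algorithm deleted above.
// Treat the `strategy` argument as an assumption; its factories live outside this diff.
import java.util.List;

import org.deeplearning4j.clustering.algorithm.BaseClusteringAlgorithm;
import org.deeplearning4j.clustering.cluster.Cluster;
import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.Point;
import org.deeplearning4j.clustering.strategy.ClusteringStrategy;
import org.nd4j.linalg.factory.Nd4j;

public class ClusteringSketch {
    static ClusterSet cluster(ClusteringStrategy strategy, boolean useKmeansPlusPlus) {
        // 500 points in 10 dimensions, one row per point
        List<Point> points = Point.toPoints(Nd4j.rand(500, 10));

        BaseClusteringAlgorithm algorithm = BaseClusteringAlgorithm.setup(strategy, useKmeansPlusPlus);
        ClusterSet clusterSet = algorithm.applyTo(points);

        for (Cluster c : clusterSet.getClusters()) {
            System.out.println(c.getId() + " -> " + c.getPoints().size() + " points");
        }
        return clusterSet;
    }
}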

View File

@ -1,38 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.algorithm;
import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.Point;
import java.util.List;
public interface ClusteringAlgorithm {
    /**
     * Apply the clustering algorithm
     * to the given set of points
     * @param points the points to cluster
     * @return the resulting cluster set
     */
ClusterSet applyTo(List<Point> points);
}

View File

@ -1,41 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.algorithm;
public enum Distance {
EUCLIDEAN("euclidean"),
COSINE_DISTANCE("cosinedistance"),
COSINE_SIMILARITY("cosinesimilarity"),
MANHATTAN("manhattan"),
DOT("dot"),
JACCARD("jaccard"),
HAMMING("hamming");
private String functionName;
private Distance(String name) {
functionName = name;
}
@Override
public String toString() {
return functionName;
}
}

View File

@ -1,105 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.cluster;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
public class CentersHolder {
private INDArray centers;
private long index = 0;
protected transient ReduceOp op;
protected ArgMin imin;
protected transient INDArray distances;
protected transient INDArray argMin;
private long rows, cols;
public CentersHolder(long rows, long cols) {
this.rows = rows;
this.cols = cols;
}
public INDArray getCenters() {
return this.centers;
}
public synchronized void addCenter(INDArray pointView) {
if (centers == null)
this.centers = Nd4j.create(pointView.dataType(), new long[] {rows, cols});
centers.putRow(index++, pointView);
}
public synchronized Pair<Double, Long> getCenterByMinDistance(Point point, Distance distanceFunction) {
if (distances == null)
distances = Nd4j.create(centers.dataType(), centers.rows());
if (argMin == null)
argMin = Nd4j.createUninitialized(DataType.LONG, new long[0]);
if (op == null) {
op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1);
imin = new ArgMin(distances, argMin);
op.setZ(distances);
}
op.setY(point.getArray());
Nd4j.getExecutioner().exec(op);
Nd4j.getExecutioner().exec(imin);
Pair<Double, Long> result = new Pair<>();
result.setFirst(distances.getDouble(argMin.getLong(0)));
result.setSecond(argMin.getLong(0));
return result;
}
public synchronized INDArray getMinDistances(Point point, Distance distanceFunction) {
if (distances == null)
distances = Nd4j.create(centers.dataType(), centers.rows());
if (argMin == null)
argMin = Nd4j.createUninitialized(DataType.LONG, new long[0]);
if (op == null) {
op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1);
imin = new ArgMin(distances, argMin);
op.setZ(distances);
}
op.setY(point.getArray());
Nd4j.getExecutioner().exec(op);
Nd4j.getExecutioner().exec(imin);
return distances;
}
}
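
A brief, illustrative sketch of how CentersHolder is driven: pre-size it, register one row per cluster center, then query the nearest center. The dimensions and the choice of Euclidean distance are arbitrary example values:

// Sketch: pre-size the holder, register k centers, then ask for the nearest one.
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.cluster.CentersHolder;
import org.deeplearning4j.clustering.cluster.Point;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.factory.Nd4j;

public class CentersHolderSketch {
    public static void main(String[] args) {
        int k = 3, dims = 4;
        CentersHolder holder = new CentersHolder(k, dims);   // rows = clusters, cols = dimensions
        for (int i = 0; i < k; i++) {
            holder.addCenter(Nd4j.rand(1, dims));             // one row per cluster center
        }

        Point query = new Point(Nd4j.rand(1, dims));
        Pair<Double, Long> nearest = holder.getCenterByMinDistance(query, Distance.EUCLIDEAN);
        System.out.println("distance=" + nearest.getFirst() + ", centerIndex=" + nearest.getSecond());
    }
}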

View File

@ -1,150 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.cluster;
import lombok.Data;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.nd4j.linalg.factory.Nd4j;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.UUID;
@Data
public class Cluster implements Serializable {
private String id = UUID.randomUUID().toString();
private String label;
private Point center;
private List<Point> points = Collections.synchronizedList(new ArrayList<Point>());
private boolean inverse = false;
private Distance distanceFunction;
public Cluster() {
super();
}
/**
*
* @param center
* @param distanceFunction
*/
public Cluster(Point center, Distance distanceFunction) {
this(center, false, distanceFunction);
}
    /**
     *
     * @param center
     * @param inverse whether the distance function is inverted (higher values mean closer, e.g. similarity measures)
     * @param distanceFunction
     */
public Cluster(Point center, boolean inverse, Distance distanceFunction) {
this.distanceFunction = distanceFunction;
this.inverse = inverse;
setCenter(center);
}
    /**
     * Get the distance from the cluster
     * center to the given point
     * @param point the point to get the distance for
     * @return the distance between the cluster center and the point
     */
public double getDistanceToCenter(Point point) {
return Nd4j.getExecutioner().execAndReturn(
ClusterUtils.createDistanceFunctionOp(distanceFunction, center.getArray(), point.getArray()))
.getFinalResult().doubleValue();
}
/**
* Add a point to the cluster
* @param point
*/
public void addPoint(Point point) {
addPoint(point, true);
}
/**
* Add a point to the cluster
* @param point the point to add
* @param moveClusterCenter whether to update
* the cluster centroid or not
*/
public void addPoint(Point point, boolean moveClusterCenter) {
if (moveClusterCenter) {
if (isInverse()) {
center.getArray().muli(points.size()).subi(point.getArray()).divi(points.size() + 1);
} else {
center.getArray().muli(points.size()).addi(point.getArray()).divi(points.size() + 1);
}
}
getPoints().add(point);
}
/**
     * Clear out the points
*/
public void removePoints() {
if (getPoints() != null)
getPoints().clear();
}
/**
* Whether the cluster is empty or not
* @return
*/
public boolean isEmpty() {
return points == null || points.isEmpty();
}
/**
* Return the point with the given id
* @param id
* @return
*/
public Point getPoint(String id) {
for (Point point : points)
if (id.equals(point.getId()))
return point;
return null;
}
/**
* Remove the point and return it
* @param id
* @return
*/
public Point removePoint(String id) {
Point removePoint = null;
for (Point point : points)
if (id.equals(point.getId()))
removePoint = point;
if (removePoint != null)
points.remove(removePoint);
return removePoint;
}
}
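
The moveClusterCenter branch of addPoint performs an incremental mean update: with n points already in the cluster, the new center is (center * n + point) / (n + 1), computed in place. A tiny sketch verifying that arithmetic with plain ND4J ops (example values only):

// Sketch: the running-mean update used by Cluster.addPoint(point, true).
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class CentroidUpdateSketch {
    public static void main(String[] args) {
        INDArray center = Nd4j.create(new double[]{1.0, 3.0});  // current mean of n points
        INDArray point  = Nd4j.create(new double[]{4.0, 0.0});  // new point joining the cluster
        int n = 2;                                               // points already in the cluster

        center.muli(n).addi(point).divi(n + 1);                  // in-place, as in addPoint
        System.out.println(center);                              // [2.0, 2.0]
    }
}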

View File

@ -1,259 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.cluster;
import lombok.Data;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import java.io.Serializable;
import java.util.*;
@Data
public class ClusterSet implements Serializable {
private Distance distanceFunction;
private List<Cluster> clusters;
private CentersHolder centersHolder;
private Map<String, String> pointDistribution;
private boolean inverse;
public ClusterSet(boolean inverse) {
this(null, inverse, null);
}
public ClusterSet(Distance distanceFunction, boolean inverse, long[] shape) {
this.distanceFunction = distanceFunction;
this.inverse = inverse;
this.clusters = Collections.synchronizedList(new ArrayList<Cluster>());
this.pointDistribution = Collections.synchronizedMap(new HashMap<String, String>());
if (shape != null)
this.centersHolder = new CentersHolder(shape[0], shape[1]);
}
public boolean isInverse() {
return inverse;
}
/**
*
* @param center
* @return
*/
public Cluster addNewClusterWithCenter(Point center) {
Cluster newCluster = new Cluster(center, distanceFunction);
getClusters().add(newCluster);
setPointLocation(center, newCluster);
centersHolder.addCenter(center.getArray());
return newCluster;
}
/**
*
* @param point
* @return
*/
public PointClassification classifyPoint(Point point) {
return classifyPoint(point, true);
}
/**
*
* @param points
*/
public void classifyPoints(List<Point> points) {
classifyPoints(points, true);
}
/**
*
* @param points
* @param moveClusterCenter
*/
public void classifyPoints(List<Point> points, boolean moveClusterCenter) {
for (Point point : points)
classifyPoint(point, moveClusterCenter);
}
/**
*
* @param point
* @param moveClusterCenter
* @return
*/
public PointClassification classifyPoint(Point point, boolean moveClusterCenter) {
Pair<Cluster, Double> nearestCluster = nearestCluster(point);
Cluster newCluster = nearestCluster.getKey();
boolean locationChange = isPointLocationChange(point, newCluster);
addPointToCluster(point, newCluster, moveClusterCenter);
return new PointClassification(nearestCluster.getKey(), nearestCluster.getValue(), locationChange);
}
private boolean isPointLocationChange(Point point, Cluster newCluster) {
if (!getPointDistribution().containsKey(point.getId()))
return true;
return !getPointDistribution().get(point.getId()).equals(newCluster.getId());
}
private void addPointToCluster(Point point, Cluster cluster, boolean moveClusterCenter) {
cluster.addPoint(point, moveClusterCenter);
setPointLocation(point, cluster);
}
private void setPointLocation(Point point, Cluster cluster) {
pointDistribution.put(point.getId(), cluster.getId());
}
/**
*
* @param point
* @return
*/
public Pair<Cluster, Double> nearestCluster(Point point) {
/*double minDistance = isInverse() ? Float.MIN_VALUE : Float.MAX_VALUE;
double currentDistance;
for (Cluster cluster : getClusters()) {
currentDistance = cluster.getDistanceToCenter(point);
if (isInverse()) {
if (currentDistance > minDistance) {
minDistance = currentDistance;
nearestCluster = cluster;
}
} else {
if (currentDistance < minDistance) {
minDistance = currentDistance;
nearestCluster = cluster;
}
}
}*/
Pair<Double, Long> nearestCenterData = centersHolder.
getCenterByMinDistance(point, distanceFunction);
Cluster nearestCluster = getClusters().get(nearestCenterData.getSecond().intValue());
double minDistance = nearestCenterData.getFirst();
return Pair.of(nearestCluster, minDistance);
}
/**
*
* @param m1
* @param m2
* @return
*/
public double getDistance(Point m1, Point m2) {
return Nd4j.getExecutioner()
.execAndReturn(ClusterUtils.createDistanceFunctionOp(distanceFunction, m1.getArray(), m2.getArray()))
.getFinalResult().doubleValue();
}
/**
*
* @param point
* @return
*/
/*public double getDistanceFromNearestCluster(Point point) {
return nearestCluster(point).getValue();
}*/
/**
*
* @param clusterId
* @return
*/
public String getClusterCenterId(String clusterId) {
Point clusterCenter = getClusterCenter(clusterId);
return clusterCenter == null ? null : clusterCenter.getId();
}
/**
*
* @param clusterId
* @return
*/
public Point getClusterCenter(String clusterId) {
Cluster cluster = getCluster(clusterId);
return cluster == null ? null : cluster.getCenter();
}
/**
*
* @param id
* @return
*/
public Cluster getCluster(String id) {
for (int i = 0, j = clusters.size(); i < j; i++)
if (id.equals(clusters.get(i).getId()))
return clusters.get(i);
return null;
}
/**
*
* @return
*/
public int getClusterCount() {
return getClusters() == null ? 0 : getClusters().size();
}
/**
*
*/
public void removePoints() {
for (Cluster cluster : getClusters())
cluster.removePoints();
}
/**
*
* @param count
* @return
*/
public List<Cluster> getMostPopulatedClusters(int count) {
List<Cluster> mostPopulated = new ArrayList<>(clusters);
Collections.sort(mostPopulated, new Comparator<Cluster>() {
public int compare(Cluster o1, Cluster o2) {
return Integer.compare(o2.getPoints().size(), o1.getPoints().size());
}
});
return mostPopulated.subList(0, count);
}
/**
*
* @return
*/
public List<Cluster> removeEmptyClusters() {
List<Cluster> emptyClusters = new ArrayList<>();
for (Cluster cluster : clusters)
if (cluster.isEmpty())
emptyClusters.add(cluster);
clusters.removeAll(emptyClusters);
return emptyClusters;
}
}
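
A short, hedged sketch of the classification path removed above; it assumes the ClusterSet was already initialised (for example by BaseClusteringAlgorithm) so that its CentersHolder contains the cluster centers:

// Sketch: classify a point against an already-initialised ClusterSet and inspect the result.
import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.Point;
import org.deeplearning4j.clustering.cluster.PointClassification;
import org.nd4j.linalg.factory.Nd4j;

public class ClassifySketch {
    static void classifyOne(ClusterSet clusterSet) {
        Point p = new Point(Nd4j.rand(1, 10));
        PointClassification pc = clusterSet.classifyPoint(p, false); // don't move the centroid
        System.out.println("cluster=" + pc.getCluster().getId()
                + ", distance=" + pc.getDistanceFromCenter()
                + ", moved=" + pc.isNewLocation());
    }
}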

View File

@ -1,531 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.cluster;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.info.ClusterInfo;
import org.deeplearning4j.clustering.info.ClusterSetInfo;
import org.deeplearning4j.clustering.optimisation.ClusteringOptimizationType;
import org.deeplearning4j.clustering.strategy.OptimisationStrategy;
import org.deeplearning4j.clustering.util.MathUtils;
import org.deeplearning4j.clustering.util.MultiThreadUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.reduce3.*;
import org.nd4j.linalg.factory.Nd4j;
import java.util.*;
import java.util.concurrent.ExecutorService;
@NoArgsConstructor(access = AccessLevel.PRIVATE)
@Slf4j
public class ClusterUtils {
    /** Classify the set of points based on cluster centers. This also adds each point to the ClusterSet */
public static ClusterSetInfo classifyPoints(final ClusterSet clusterSet, List<Point> points,
ExecutorService executorService) {
final ClusterSetInfo clusterSetInfo = ClusterSetInfo.initialize(clusterSet, true);
List<Runnable> tasks = new ArrayList<>();
for (final Point point : points) {
//tasks.add(new Runnable() {
// public void run() {
try {
PointClassification result = classifyPoint(clusterSet, point);
if (result.isNewLocation())
clusterSetInfo.getPointLocationChange().incrementAndGet();
clusterSetInfo.getClusterInfo(result.getCluster().getId()).getPointDistancesFromCenter()
.put(point.getId(), result.getDistanceFromCenter());
} catch (Throwable t) {
log.warn("Error classifying point", t);
}
// }
}
//MultiThreadUtils.parallelTasks(tasks, executorService);
return clusterSetInfo;
}
public static PointClassification classifyPoint(ClusterSet clusterSet, Point point) {
return clusterSet.classifyPoint(point, false);
}
public static void refreshClustersCenters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo,
ExecutorService executorService) {
List<Runnable> tasks = new ArrayList<>();
int nClusters = clusterSet.getClusterCount();
for (int i = 0; i < nClusters; i++) {
final Cluster cluster = clusterSet.getClusters().get(i);
//tasks.add(new Runnable() {
// public void run() {
try {
final ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
refreshClusterCenter(cluster, clusterInfo);
deriveClusterInfoDistanceStatistics(clusterInfo);
} catch (Throwable t) {
log.warn("Error refreshing cluster centers", t);
}
// }
//});
}
//MultiThreadUtils.parallelTasks(tasks, executorService);
}
public static void refreshClusterCenter(Cluster cluster, ClusterInfo clusterInfo) {
int pointsCount = cluster.getPoints().size();
if (pointsCount == 0)
return;
Point center = new Point(Nd4j.create(cluster.getPoints().get(0).getArray().length()));
for (Point point : cluster.getPoints()) {
INDArray arr = point.getArray();
if (cluster.isInverse())
center.getArray().subi(arr);
else
center.getArray().addi(arr);
}
center.getArray().divi(pointsCount);
cluster.setCenter(center);
}
/**
*
* @param info
*/
public static void deriveClusterInfoDistanceStatistics(ClusterInfo info) {
int pointCount = info.getPointDistancesFromCenter().size();
if (pointCount == 0)
return;
double[] distances =
ArrayUtils.toPrimitive(info.getPointDistancesFromCenter().values().toArray(new Double[] {}));
double max = info.isInverse() ? MathUtils.min(distances) : MathUtils.max(distances);
double total = MathUtils.sum(distances);
info.setMaxPointDistanceFromCenter(max);
info.setTotalPointDistanceFromCenter(total);
info.setAveragePointDistanceFromCenter(total / pointCount);
info.setPointDistanceFromCenterVariance(MathUtils.variance(distances));
}
/**
*
* @param clusterSet
* @param points
* @param previousDxs
* @param executorService
* @return
*/
public static INDArray computeSquareDistancesFromNearestCluster(final ClusterSet clusterSet,
final List<Point> points, INDArray previousDxs, ExecutorService executorService) {
final int pointsCount = points.size();
final INDArray dxs = Nd4j.create(pointsCount);
final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1);
List<Runnable> tasks = new ArrayList<>();
for (int i = 0; i < pointsCount; i++) {
final int i2 = i;
//tasks.add(new Runnable() {
// public void run() {
try {
Point point = points.get(i2);
double dist = clusterSet.isInverse() ? newCluster.getDistanceToCenter(point)
: Math.pow(newCluster.getDistanceToCenter(point), 2);
dxs.putScalar(i2, /*clusterSet.isInverse() ? dist :*/ dist);
} catch (Throwable t) {
log.warn("Error computing squared distance from nearest cluster", t);
}
// }
//});
}
//MultiThreadUtils.parallelTasks(tasks, executorService);
for (int i = 0; i < pointsCount; i++) {
double previousMinDistance = previousDxs.getDouble(i);
if (clusterSet.isInverse()) {
if (dxs.getDouble(i) < previousMinDistance) {
dxs.putScalar(i, previousMinDistance);
}
} else if (dxs.getDouble(i) > previousMinDistance)
dxs.putScalar(i, previousMinDistance);
}
return dxs;
}
public static INDArray computeWeightedProbaDistancesFromNearestCluster(final ClusterSet clusterSet,
final List<Point> points, INDArray previousDxs) {
final int pointsCount = points.size();
final INDArray dxs = Nd4j.create(pointsCount);
final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1);
        double sum = 0.0;
for (int i = 0; i < pointsCount; i++) {
Point point = points.get(i);
double dist = Math.pow(newCluster.getDistanceToCenter(point), 2);
sum += dist;
dxs.putScalar(i, sum);
}
return dxs;
}
/**
*
* @param clusterSet
* @return
*/
public static ClusterSetInfo computeClusterSetInfo(ClusterSet clusterSet) {
ExecutorService executor = MultiThreadUtils.newExecutorService();
ClusterSetInfo info = computeClusterSetInfo(clusterSet, executor);
executor.shutdownNow();
return info;
}
public static ClusterSetInfo computeClusterSetInfo(final ClusterSet clusterSet, ExecutorService executorService) {
final ClusterSetInfo info = new ClusterSetInfo(clusterSet.isInverse(), true);
int clusterCount = clusterSet.getClusterCount();
List<Runnable> tasks = new ArrayList<>();
for (int i = 0; i < clusterCount; i++) {
final Cluster cluster = clusterSet.getClusters().get(i);
//tasks.add(new Runnable() {
// public void run() {
try {
info.getClustersInfos().put(cluster.getId(),
computeClusterInfos(cluster, clusterSet.getDistanceFunction()));
} catch (Throwable t) {
log.warn("Error computing cluster set info", t);
}
//}
//});
}
//MultiThreadUtils.parallelTasks(tasks, executorService);
//tasks = new ArrayList<>();
for (int i = 0; i < clusterCount; i++) {
final int clusterIdx = i;
final Cluster fromCluster = clusterSet.getClusters().get(i);
//tasks.add(new Runnable() {
//public void run() {
try {
for (int k = clusterIdx + 1, l = clusterSet.getClusterCount(); k < l; k++) {
Cluster toCluster = clusterSet.getClusters().get(k);
double distance = Nd4j.getExecutioner()
.execAndReturn(ClusterUtils.createDistanceFunctionOp(
clusterSet.getDistanceFunction(),
fromCluster.getCenter().getArray(),
toCluster.getCenter().getArray()))
.getFinalResult().doubleValue();
info.getDistancesBetweenClustersCenters().put(fromCluster.getId(), toCluster.getId(),
distance);
}
} catch (Throwable t) {
log.warn("Error computing distances", t);
}
// }
//});
}
//MultiThreadUtils.parallelTasks(tasks, executorService);
return info;
}
/**
*
* @param cluster
* @param distanceFunction
* @return
*/
public static ClusterInfo computeClusterInfos(Cluster cluster, Distance distanceFunction) {
ClusterInfo info = new ClusterInfo(cluster.isInverse(), true);
for (int i = 0, j = cluster.getPoints().size(); i < j; i++) {
Point point = cluster.getPoints().get(i);
//shouldn't need to inverse here. other parts of
//the code should interpret the "distance" or score here
double distance = Nd4j.getExecutioner()
.execAndReturn(ClusterUtils.createDistanceFunctionOp(distanceFunction,
cluster.getCenter().getArray(), point.getArray()))
.getFinalResult().doubleValue();
info.getPointDistancesFromCenter().put(point.getId(), distance);
double diff = info.getTotalPointDistanceFromCenter() + distance;
info.setTotalPointDistanceFromCenter(diff);
}
if (!cluster.getPoints().isEmpty())
info.setAveragePointDistanceFromCenter(info.getTotalPointDistanceFromCenter() / cluster.getPoints().size());
return info;
}
/**
*
* @param optimization
* @param clusterSet
* @param clusterSetInfo
* @param executor
* @return
*/
public static boolean applyOptimization(OptimisationStrategy optimization, ClusterSet clusterSet,
ClusterSetInfo clusterSetInfo, ExecutorService executor) {
if (optimization.isClusteringOptimizationType(
ClusteringOptimizationType.MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE)) {
int splitCount = ClusterUtils.splitClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet,
clusterSetInfo, optimization.getClusteringOptimizationValue(), executor);
return splitCount > 0;
}
if (optimization.isClusteringOptimizationType(
ClusteringOptimizationType.MINIMIZE_MAXIMUM_POINT_TO_CENTER_DISTANCE)) {
int splitCount = ClusterUtils.splitClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet,
clusterSetInfo, optimization.getClusteringOptimizationValue(), executor);
return splitCount > 0;
}
return false;
}
/**
*
* @param clusterSet
* @param info
* @param count
* @return
*/
public static List<Cluster> getMostSpreadOutClusters(final ClusterSet clusterSet, final ClusterSetInfo info,
int count) {
List<Cluster> clusters = new ArrayList<>(clusterSet.getClusters());
Collections.sort(clusters, new Comparator<Cluster>() {
public int compare(Cluster o1, Cluster o2) {
Double o1TotalDistance = info.getClusterInfo(o1.getId()).getTotalPointDistanceFromCenter();
Double o2TotalDistance = info.getClusterInfo(o2.getId()).getTotalPointDistanceFromCenter();
int comp = o1TotalDistance.compareTo(o2TotalDistance);
return !clusterSet.getClusters().get(0).isInverse() ? -comp : comp;
}
});
return clusters.subList(0, count);
}
/**
*
* @param clusterSet
* @param info
* @param maximumAverageDistance
* @return
*/
public static List<Cluster> getClustersWhereAverageDistanceFromCenterGreaterThan(final ClusterSet clusterSet,
final ClusterSetInfo info, double maximumAverageDistance) {
List<Cluster> clusters = new ArrayList<>();
for (Cluster cluster : clusterSet.getClusters()) {
ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId());
if (clusterInfo != null) {
//distances
if (clusterInfo.isInverse()) {
if (clusterInfo.getAveragePointDistanceFromCenter() < maximumAverageDistance)
clusters.add(cluster);
} else {
if (clusterInfo.getAveragePointDistanceFromCenter() > maximumAverageDistance)
clusters.add(cluster);
}
}
}
return clusters;
}
/**
*
* @param clusterSet
* @param info
* @param maximumDistance
* @return
*/
public static List<Cluster> getClustersWhereMaximumDistanceFromCenterGreaterThan(final ClusterSet clusterSet,
final ClusterSetInfo info, double maximumDistance) {
List<Cluster> clusters = new ArrayList<>();
for (Cluster cluster : clusterSet.getClusters()) {
ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId());
if (clusterInfo != null) {
if (clusterInfo.isInverse() && clusterInfo.getMaxPointDistanceFromCenter() < maximumDistance) {
clusters.add(cluster);
} else if (clusterInfo.getMaxPointDistanceFromCenter() > maximumDistance) {
clusters.add(cluster);
}
}
}
return clusters;
}
/**
*
* @param clusterSet
* @param clusterSetInfo
* @param count
* @param executorService
* @return
*/
public static int splitMostSpreadOutClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count,
ExecutorService executorService) {
List<Cluster> clustersToSplit = getMostSpreadOutClusters(clusterSet, clusterSetInfo, count);
splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService);
return clustersToSplit.size();
}
/**
*
* @param clusterSet
* @param clusterSetInfo
* @param maxWithinClusterDistance
* @param executorService
* @return
*/
public static int splitClustersWhereAverageDistanceFromCenterGreaterThan(ClusterSet clusterSet,
ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) {
List<Cluster> clustersToSplit = getClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo,
maxWithinClusterDistance);
splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService);
return clustersToSplit.size();
}
/**
*
* @param clusterSet
* @param clusterSetInfo
* @param maxWithinClusterDistance
* @param executorService
* @return
*/
public static int splitClustersWhereMaximumDistanceFromCenterGreaterThan(ClusterSet clusterSet,
ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) {
List<Cluster> clustersToSplit = getClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo,
maxWithinClusterDistance);
splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService);
return clustersToSplit.size();
}
/**
*
* @param clusterSet
* @param clusterSetInfo
* @param count
* @param executorService
*/
public static void splitMostPopulatedClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count,
ExecutorService executorService) {
List<Cluster> clustersToSplit = clusterSet.getMostPopulatedClusters(count);
splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService);
}
/**
*
* @param clusterSet
* @param clusterSetInfo
* @param clusters
* @param maxDistance
* @param executorService
*/
public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo,
List<Cluster> clusters, final double maxDistance, ExecutorService executorService) {
final Random random = new Random();
List<Runnable> tasks = new ArrayList<>();
for (final Cluster cluster : clusters) {
tasks.add(new Runnable() {
public void run() {
try {
ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
List<String> fartherPoints = clusterInfo.getPointsFartherFromCenterThan(maxDistance);
int rank = Math.min(fartherPoints.size(), 3);
String pointId = fartherPoints.get(random.nextInt(rank));
Point point = cluster.removePoint(pointId);
clusterSet.addNewClusterWithCenter(point);
} catch (Throwable t) {
log.warn("Error splitting clusters", t);
}
}
});
}
MultiThreadUtils.parallelTasks(tasks, executorService);
}
/**
*
* @param clusterSet
* @param clusterSetInfo
* @param clusters
* @param executorService
*/
public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo,
List<Cluster> clusters, ExecutorService executorService) {
final Random random = new Random();
List<Runnable> tasks = new ArrayList<>();
for (final Cluster cluster : clusters) {
tasks.add(new Runnable() {
public void run() {
try {
Point point = cluster.getPoints().remove(random.nextInt(cluster.getPoints().size()));
clusterSet.addNewClusterWithCenter(point);
} catch (Throwable t) {
log.warn("Error Splitting clusters (2)", t);
}
}
});
}
MultiThreadUtils.parallelTasks(tasks, executorService);
}
public static ReduceOp createDistanceFunctionOp(Distance distanceFunction, INDArray x, INDArray y, int...dimensions){
val op = createDistanceFunctionOp(distanceFunction, x, y);
op.setDimensions(dimensions);
return op;
}
public static ReduceOp createDistanceFunctionOp(Distance distanceFunction, INDArray x, INDArray y){
switch (distanceFunction){
case COSINE_DISTANCE:
return new CosineDistance(x,y);
case COSINE_SIMILARITY:
return new CosineSimilarity(x,y);
case DOT:
return new Dot(x,y);
case EUCLIDEAN:
return new EuclideanDistance(x,y);
case JACCARD:
return new JaccardDistance(x,y);
case MANHATTAN:
                return new ManhattanDistance(x,y);
            case HAMMING:
                return new HammingDistance(x,y);
            default:
throw new IllegalStateException("Unknown distance function: " + distanceFunction);
}
}
}
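
A small sketch of the distance helper above, mirroring the execAndReturn pattern used in Cluster.getDistanceToCenter; the vectors are arbitrary example values:

// Sketch: computing a pairwise distance through the same helper used throughout this module.
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.cluster.ClusterUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class DistanceOpSketch {
    public static void main(String[] args) {
        INDArray x = Nd4j.create(new double[]{0.0, 0.0, 0.0});
        INDArray y = Nd4j.create(new double[]{3.0, 4.0, 0.0});

        double d = Nd4j.getExecutioner()
                .execAndReturn(ClusterUtils.createDistanceFunctionOp(Distance.EUCLIDEAN, x, y))
                .getFinalResult().doubleValue();

        System.out.println(d); // 5.0 for these vectors
    }
}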

View File

@ -1,107 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.cluster;
import lombok.AccessLevel;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;
/**
*
*/
@Data
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public class Point implements Serializable {
private static final long serialVersionUID = -6658028541426027226L;
private String id = UUID.randomUUID().toString();
private String label;
private INDArray array;
/**
*
* @param array
*/
public Point(INDArray array) {
super();
this.array = array;
}
/**
*
* @param id
* @param array
*/
public Point(String id, INDArray array) {
super();
this.id = id;
this.array = array;
}
public Point(String id, String label, double[] data) {
this(id, label, Nd4j.create(data));
}
public Point(String id, String label, INDArray array) {
super();
this.id = id;
this.label = label;
this.array = array;
}
/**
*
* @param matrix
* @return
*/
public static List<Point> toPoints(INDArray matrix) {
List<Point> arr = new ArrayList<>(matrix.rows());
for (int i = 0; i < matrix.rows(); i++) {
arr.add(new Point(matrix.getRow(i)));
}
return arr;
}
/**
*
* @param vectors
* @return
*/
public static List<Point> toPoints(List<INDArray> vectors) {
List<Point> points = new ArrayList<>();
for (INDArray vector : vectors)
points.add(new Point(vector));
return points;
}
}

View File

@ -1,40 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.cluster;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.io.Serializable;
@Data
@NoArgsConstructor(access = AccessLevel.PROTECTED)
@AllArgsConstructor
public class PointClassification implements Serializable {
private Cluster cluster;
private double distanceFromCenter;
private boolean newLocation;
}

View File

@ -1,37 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.condition;
import org.deeplearning4j.clustering.iteration.IterationHistory;
/**
*
*/
public interface ClusteringAlgorithmCondition {
/**
*
* @param iterationHistory
* @return
*/
boolean isSatisfied(IterationHistory iterationHistory);
}

View File

@ -1,69 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.condition;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor;
import org.deeplearning4j.clustering.iteration.IterationHistory;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.LessThan;
import java.io.Serializable;
@NoArgsConstructor(access = AccessLevel.PROTECTED)
@AllArgsConstructor(access = AccessLevel.PROTECTED)
public class ConvergenceCondition implements ClusteringAlgorithmCondition, Serializable {
private Condition convergenceCondition;
private double pointsDistributionChangeRate;
/**
*
* @param pointsDistributionChangeRate
* @return
*/
public static ConvergenceCondition distributionVariationRateLessThan(double pointsDistributionChangeRate) {
Condition condition = new LessThan(pointsDistributionChangeRate);
return new ConvergenceCondition(condition, pointsDistributionChangeRate);
}
/**
*
* @param iterationHistory
* @return
*/
public boolean isSatisfied(IterationHistory iterationHistory) {
int iterationCount = iterationHistory.getIterationCount();
if (iterationCount <= 1)
return false;
double variation = iterationHistory.getMostRecentClusterSetInfo().getPointLocationChange().get();
variation /= iterationHistory.getMostRecentClusterSetInfo().getPointsCount();
return convergenceCondition.apply(variation);
}
}
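
A minimal sketch of the removed convergence condition: stop once fewer than 5% of points change cluster between iterations. The IterationHistory instance is assumed to be the one maintained by the clustering loop:

// Sketch: build the condition via its factory and test it against the iteration history.
import org.deeplearning4j.clustering.condition.ConvergenceCondition;
import org.deeplearning4j.clustering.iteration.IterationHistory;

public class ConvergenceSketch {
    static boolean converged(IterationHistory history) {
        ConvergenceCondition condition = ConvergenceCondition.distributionVariationRateLessThan(0.05);
        return condition.isSatisfied(history);
    }
}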

View File

@ -1,61 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.condition;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.deeplearning4j.clustering.iteration.IterationHistory;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.GreaterThanOrEqual;
import java.io.Serializable;
/**
*
*/
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public class FixedIterationCountCondition implements ClusteringAlgorithmCondition, Serializable {
private Condition iterationCountCondition;
protected FixedIterationCountCondition(int initialClusterCount) {
iterationCountCondition = new GreaterThanOrEqual(initialClusterCount);
}
/**
*
* @param iterationCount
* @return
*/
public static FixedIterationCountCondition iterationCountGreaterThan(int iterationCount) {
return new FixedIterationCountCondition(iterationCount);
}
/**
*
* @param iterationHistory
* @return
*/
public boolean isSatisfied(IterationHistory iterationHistory) {
return iterationCountCondition.apply(iterationHistory == null ? 0 : iterationHistory.getIterationCount());
}
}

View File

@ -1,82 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.condition;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor;
import org.deeplearning4j.clustering.iteration.IterationHistory;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.LessThan;
import java.io.Serializable;
/**
*
*/
@NoArgsConstructor(access = AccessLevel.PROTECTED)
@AllArgsConstructor
public class VarianceVariationCondition implements ClusteringAlgorithmCondition, Serializable {
private Condition varianceVariationCondition;
private int period;
/**
*
* @param varianceVariation
* @param period
* @return
*/
public static VarianceVariationCondition varianceVariationLessThan(double varianceVariation, int period) {
Condition condition = new LessThan(varianceVariation);
return new VarianceVariationCondition(condition, period);
}
/**
*
* @param iterationHistory
* @return
*/
public boolean isSatisfied(IterationHistory iterationHistory) {
if (iterationHistory.getIterationCount() <= period)
return false;
for (int i = 0, j = iterationHistory.getIterationCount(); i < period; i++) {
double variation = iterationHistory.getIterationInfo(j - i).getClusterSetInfo()
.getPointDistanceFromClusterVariance();
variation -= iterationHistory.getIterationInfo(j - i - 1).getClusterSetInfo()
.getPointDistanceFromClusterVariance();
variation /= iterationHistory.getIterationInfo(j - i - 1).getClusterSetInfo()
.getPointDistanceFromClusterVariance();
if (!varianceVariationCondition.apply(variation))
return false;
}
return true;
}
}

View File

@ -1,114 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.info;
import lombok.Data;
import java.io.Serializable;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
/**
*
*/
@Data
public class ClusterInfo implements Serializable {
private double averagePointDistanceFromCenter;
private double maxPointDistanceFromCenter;
private double pointDistanceFromCenterVariance;
private double totalPointDistanceFromCenter;
private boolean inverse;
private Map<String, Double> pointDistancesFromCenter = new ConcurrentHashMap<>();
public ClusterInfo(boolean inverse) {
this(false, inverse);
}
/**
*
* @param threadSafe
*/
public ClusterInfo(boolean threadSafe, boolean inverse) {
super();
this.inverse = inverse;
if (threadSafe) {
pointDistancesFromCenter = Collections.synchronizedMap(pointDistancesFromCenter);
}
}
/**
*
* @return
*/
public Set<Map.Entry<String, Double>> getSortedPointDistancesFromCenter() {
SortedSet<Map.Entry<String, Double>> sortedEntries = new TreeSet<>(new Comparator<Map.Entry<String, Double>>() {
@Override
public int compare(Map.Entry<String, Double> e1, Map.Entry<String, Double> e2) {
int res = e1.getValue().compareTo(e2.getValue());
return res != 0 ? res : 1;
}
});
sortedEntries.addAll(pointDistancesFromCenter.entrySet());
return sortedEntries;
}
/**
*
* @return the point distances from the center, sorted in descending order (equal values are preserved)
*/
public Set<Map.Entry<String, Double>> getReverseSortedPointDistancesFromCenter() {
SortedSet<Map.Entry<String, Double>> sortedEntries = new TreeSet<>(new Comparator<Map.Entry<String, Double>>() {
@Override
public int compare(Map.Entry<String, Double> e1, Map.Entry<String, Double> e2) {
int res = e1.getValue().compareTo(e2.getValue());
return -(res != 0 ? res : 1);
}
});
sortedEntries.addAll(pointDistancesFromCenter.entrySet());
return sortedEntries;
}
/**
*
* @param maxDistance the distance threshold
* @return the ids of all points whose distance from the center exceeds the threshold
*/
public List<String> getPointsFartherFromCenterThan(double maxDistance) {
Set<Map.Entry<String, Double>> sorted = getReverseSortedPointDistancesFromCenter();
List<String> ids = new ArrayList<>();
for (Map.Entry<String, Double> entry : sorted) {
// in inverse (similarity) mode the stored values are compared against -maxDistance
if (inverse && entry.getValue() < -maxDistance)
break;
else if (entry.getValue() > maxDistance)
break;
ids.add(entry.getKey());
}
return ids;
}
}

View File

@ -1,142 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.info;
import org.nd4j.shade.guava.collect.HashBasedTable;
import org.nd4j.shade.guava.collect.Table;
import org.deeplearning4j.clustering.cluster.Cluster;
import org.deeplearning4j.clustering.cluster.ClusterSet;
import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
public class ClusterSetInfo implements Serializable {
private Map<String, ClusterInfo> clustersInfos = new HashMap<>();
private Table<String, String, Double> distancesBetweenClustersCenters = HashBasedTable.create();
private AtomicInteger pointLocationChange;
private boolean threadSafe;
private boolean inverse;
public ClusterSetInfo(boolean inverse) {
this(inverse, false);
}
/**
*
* @param inverse whether distances are inverted (similarity mode)
* @param threadSafe whether the cluster info map should be wrapped in a synchronized map
*/
public ClusterSetInfo(boolean inverse, boolean threadSafe) {
this.pointLocationChange = new AtomicInteger(0);
this.threadSafe = threadSafe;
this.inverse = inverse;
if (threadSafe) {
clustersInfos = Collections.synchronizedMap(clustersInfos);
}
}
/**
*
* @param clusterSet the cluster set to collect statistics for
* @param threadSafe whether the resulting info object should be thread safe
* @return a ClusterSetInfo with one empty ClusterInfo per cluster in the set
*/
public static ClusterSetInfo initialize(ClusterSet clusterSet, boolean threadSafe) {
ClusterSetInfo info = new ClusterSetInfo(clusterSet.isInverse(), threadSafe);
for (int i = 0, j = clusterSet.getClusterCount(); i < j; i++)
info.addClusterInfo(clusterSet.getClusters().get(i).getId());
return info;
}
public void removeClusterInfos(List<Cluster> clusters) {
for (Cluster cluster : clusters) {
clustersInfos.remove(cluster.getId());
}
}
public ClusterInfo addClusterInfo(String clusterId) {
// pass both flags explicitly; the single-argument ClusterInfo constructor treats its argument as the inverse flag
ClusterInfo clusterInfo = new ClusterInfo(this.threadSafe, this.inverse);
clustersInfos.put(clusterId, clusterInfo);
return clusterInfo;
}
public ClusterInfo getClusterInfo(String clusterId) {
return clustersInfos.get(clusterId);
}
public double getAveragePointDistanceFromClusterCenter() {
if (clustersInfos == null || clustersInfos.isEmpty())
return 0;
double average = 0;
for (ClusterInfo info : clustersInfos.values())
average += info.getAveragePointDistanceFromCenter();
return average / clustersInfos.size();
}
public double getPointDistanceFromClusterVariance() {
if (clustersInfos == null || clustersInfos.isEmpty())
return 0;
double average = 0;
for (ClusterInfo info : clustersInfos.values())
average += info.getPointDistanceFromCenterVariance();
return average / clustersInfos.size();
}
public int getPointsCount() {
int count = 0;
for (ClusterInfo clusterInfo : clustersInfos.values())
count += clusterInfo.getPointDistancesFromCenter().size();
return count;
}
public Map<String, ClusterInfo> getClustersInfos() {
return clustersInfos;
}
public void setClustersInfos(Map<String, ClusterInfo> clustersInfos) {
this.clustersInfos = clustersInfos;
}
public Table<String, String, Double> getDistancesBetweenClustersCenters() {
return distancesBetweenClustersCenters;
}
public void setDistancesBetweenClustersCenters(Table<String, String, Double> interClusterDistances) {
this.distancesBetweenClustersCenters = interClusterDistances;
}
public AtomicInteger getPointLocationChange() {
return pointLocationChange;
}
public void setPointLocationChange(AtomicInteger pointLocationChange) {
this.pointLocationChange = pointLocationChange;
}
}

View File

@ -1,72 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.iteration;
import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.clustering.info.ClusterSetInfo;
import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
public class IterationHistory implements Serializable {
@Getter
@Setter
private Map<Integer, IterationInfo> iterationsInfos = new HashMap<>();
/**
*
* @return the ClusterSetInfo of the most recent iteration, or null if no iteration has been recorded
*/
public ClusterSetInfo getMostRecentClusterSetInfo() {
IterationInfo iterationInfo = getMostRecentIterationInfo();
return iterationInfo == null ? null : iterationInfo.getClusterSetInfo();
}
/**
*
* @return the most recent IterationInfo, or null if no iteration has been recorded
*/
public IterationInfo getMostRecentIterationInfo() {
return getIterationInfo(getIterationCount() - 1);
}
/**
*
* @return the number of recorded iterations
*/
public int getIterationCount() {
return getIterationsInfos().size();
}
/**
*
* @param iterationIdx the iteration index
* @return the IterationInfo recorded for that iteration, or null if none exists
*/
public IterationInfo getIterationInfo(int iterationIdx) {
return getIterationsInfos().get(iterationIdx);
}
}

View File

@ -1,49 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.iteration;
import lombok.AccessLevel;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.deeplearning4j.clustering.info.ClusterSetInfo;
import java.io.Serializable;
@Data
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public class IterationInfo implements Serializable {
private int index;
private ClusterSetInfo clusterSetInfo;
private boolean strategyApplied;
public IterationInfo(int index) {
super();
this.index = index;
}
public IterationInfo(int index, ClusterSetInfo clusterSetInfo) {
super();
this.index = index;
this.clusterSetInfo = clusterSetInfo;
}
}

View File

@ -1,142 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.kdtree;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.custom.KnnMinDistance;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import java.io.Serializable;
public class HyperRect implements Serializable {
//private List<Interval> points;
private float[] lowerEnds;
private float[] higherEnds;
private INDArray lowerEndsIND;
private INDArray higherEndsIND;
public HyperRect(float[] lowerEndsIn, float[] higherEndsIn) {
this.lowerEnds = new float[lowerEndsIn.length];
this.higherEnds = new float[lowerEndsIn.length];
System.arraycopy(lowerEndsIn, 0 , this.lowerEnds, 0, lowerEndsIn.length);
System.arraycopy(higherEndsIn, 0 , this.higherEnds, 0, higherEndsIn.length);
lowerEndsIND = Nd4j.createFromArray(lowerEnds);
higherEndsIND = Nd4j.createFromArray(higherEnds);
}
public HyperRect(float[] point) {
this(point, point);
}
public HyperRect(Pair<float[], float[]> ends) {
this(ends.getFirst(), ends.getSecond());
}
public void enlargeTo(INDArray point) {
float[] pointAsArray = point.toFloatVector();
for (int i = 0; i < lowerEnds.length; i++) {
float p = pointAsArray[i];
if (lowerEnds[i] > p)
lowerEnds[i] = p;
else if (higherEnds[i] < p)
higherEnds[i] = p;
}
// keep the INDArray views used by minDistance in sync with the enlarged bounds
lowerEndsIND = Nd4j.createFromArray(lowerEnds);
higherEndsIND = Nd4j.createFromArray(higherEnds);
}
public static Pair<float[],float[]> point(INDArray vector) {
Pair<float[],float[]> ret = new Pair<>();
float[] curr = new float[(int)vector.length()];
for (int i = 0; i < vector.length(); i++) {
curr[i] = vector.getFloat(i);
}
ret.setFirst(curr);
ret.setSecond(curr);
return ret;
}
/*public List<Boolean> contains(INDArray hPoint) {
List<Boolean> ret = new ArrayList<>();
for (int i = 0; i < hPoint.length(); i++) {
ret.add(lowerEnds[i] <= hPoint.getDouble(i) &&
higherEnds[i] >= hPoint.getDouble(i));
}
return ret;
}*/
public double minDistance(INDArray hPoint, INDArray output) {
Nd4j.exec(new KnnMinDistance(hPoint, lowerEndsIND, higherEndsIND, output));
return output.getFloat(0);
/*double ret = 0.0;
double[] pointAsArray = hPoint.toDoubleVector();
for (int i = 0; i < pointAsArray.length; i++) {
double p = pointAsArray[i];
if (!(lowerEnds[i] <= p || higherEnds[i] <= p)) {
if (p < lowerEnds[i])
ret += Math.pow((p - lowerEnds[i]), 2);
else
ret += Math.pow((p - higherEnds[i]), 2);
}
}
ret = Math.pow(ret, 0.5);
return ret;*/
}
public HyperRect getUpper(INDArray hPoint, int desc) {
//Interval interval = points.get(desc);
float higher = higherEnds[desc];
float d = hPoint.getFloat(desc);
if (higher < d)
return null;
HyperRect ret = new HyperRect(lowerEnds,higherEnds);
if (ret.lowerEnds[desc] < d)
ret.lowerEnds[desc] = d;
return ret;
}
public HyperRect getLower(INDArray hPoint, int desc) {
//Interval interval = points.get(desc);
float lower = lowerEnds[desc];
float d = hPoint.getFloat(desc);
if (lower > d)
return null;
HyperRect ret = new HyperRect(lowerEnds,higherEnds);
//Interval i2 = ret.points.get(desc);
if (ret.higherEnds[desc] > d)
ret.higherEnds[desc] = d;
return ret;
}
@Override
public String toString() {
String retVal = "";
retVal += "[";
for (int i = 0; i < lowerEnds.length; ++i) {
retVal += "(" + lowerEnds[i] + " - " + higherEnds[i] + ") ";
}
retVal += "]";
return retVal;
}
}
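
A small sketch of the HyperRect semantics, using only the constructors and methods shown above; the coordinates are illustrative and the printed values are expectations, not verified output.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class HyperRectSketch {
    public static void main(String[] args) {
        // a 2-d rectangle spanning [0,1] x [0,1]
        HyperRect rect = new HyperRect(new float[]{0f, 0f}, new float[]{1f, 1f});

        // grow the rectangle to cover a point outside it; it should now span [0,2] x [0,1]
        rect.enlargeTo(Nd4j.createFromArray(2f, 0.5f));

        // minimum distance from a point to the rectangle (0 if the point lies inside)
        INDArray out = Nd4j.scalar(0.0f);
        System.out.println(rect.minDistance(Nd4j.createFromArray(0.5f, 0.5f), out)); // expected 0.0 (inside)
        System.out.println(rect.minDistance(Nd4j.createFromArray(3f, 0.5f), out));   // expected 1.0 (one unit right of x-max)
    }
}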

View File

@ -1,370 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.kdtree;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.bool.Any;
import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
public class KDTree implements Serializable {
private KDNode root;
private int dims = 100;
public final static int GREATER = 1;
public final static int LESS = 0;
private int size = 0;
private HyperRect rect;
public KDTree(int dims) {
this.dims = dims;
}
/**
* Insert a point in to the tree
* @param point the point to insert
*/
public void insert(INDArray point) {
if (!point.isVector() || point.length() != dims)
throw new IllegalArgumentException("Point must be a vector of length " + dims);
if (root == null) {
root = new KDNode(point);
rect = new HyperRect(/*HyperRect.point(point)*/ point.toFloatVector());
} else {
int disc = 0;
KDNode node = root;
KDNode insert = new KDNode(point);
int successor;
while (true) {
//exactly equal
INDArray pt = node.getPoint();
INDArray countEq = Nd4j.getExecutioner().execAndReturn(new Any(pt.neq(point))).z();
if (countEq.getInt(0) == 0) {
return;
} else {
successor = successor(node, point, disc);
KDNode child;
if (successor < 1)
child = node.getLeft();
else
child = node.getRight();
if (child == null)
break;
disc = (disc + 1) % dims;
node = child;
}
}
if (successor < 1)
node.setLeft(insert);
else
node.setRight(insert);
rect.enlargeTo(point);
insert.setParent(node);
}
size++;
}
public INDArray delete(INDArray point) {
KDNode node = root;
int _disc = 0;
while (node != null) {
if (node.point == point)
break;
int successor = successor(node, point, _disc);
if (successor < 1)
node = node.getLeft();
else
node = node.getRight();
_disc = (_disc + 1) % dims;
}
if (node != null) {
if (node == root) {
root = delete(root, _disc);
} else
node = delete(node, _disc);
size--;
if (size == 1) {
rect = new HyperRect(HyperRect.point(point));
} else if (size == 0)
rect = null;
}
return node == null ? null : node.getPoint();
}
// Share this data for recursive calls of "knn"
private float currentDistance;
private INDArray currentPoint;
private INDArray minDistance = Nd4j.scalar(0.f);
public List<Pair<Float, INDArray>> knn(INDArray point, float distance) {
List<Pair<Float, INDArray>> best = new ArrayList<>();
currentDistance = distance;
currentPoint = point;
knn(root, rect, best, 0);
Collections.sort(best, new Comparator<Pair<Float, INDArray>>() {
@Override
public int compare(Pair<Float, INDArray> o1, Pair<Float, INDArray> o2) {
return Float.compare(o1.getKey(), o2.getKey());
}
});
return best;
}
private void knn(KDNode node, HyperRect rect, List<Pair<Float, INDArray>> best, int _disc) {
if (node == null || rect == null || rect.minDistance(currentPoint, minDistance) > currentDistance)
return;
int _discNext = (_disc + 1) % dims;
float distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(currentPoint,node.point, minDistance)).getFinalResult()
.floatValue();
if (distance <= currentDistance) {
best.add(Pair.of(distance, node.getPoint()));
}
HyperRect lower = rect.getLower(node.point, _disc);
HyperRect upper = rect.getUpper(node.point, _disc);
knn(node.getLeft(), lower, best, _discNext);
knn(node.getRight(), upper, best, _discNext);
}
/**
* Query for the nearest neighbor of a point.
* @param point the point to query for
* @return a pair of the distance to the nearest neighbor and the neighbor itself
*/
public Pair<Double, INDArray> nn(INDArray point) {
return nn(root, point, rect, Double.POSITIVE_INFINITY, null, 0);
}
private Pair<Double, INDArray> nn(KDNode node, INDArray point, HyperRect rect, double dist, INDArray best,
int _disc) {
if (node == null || rect.minDistance(point, minDistance) > dist)
return Pair.of(Double.POSITIVE_INFINITY, null);
int _discNext = (_disc + 1) % dims;
// distance between the query point and this node's point
double dist2 = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(point, node.getPoint())).getFinalResult().doubleValue();
if (dist2 < dist) {
best = node.getPoint();
dist = dist2;
}
HyperRect lower = rect.getLower(node.point, _disc);
HyperRect upper = rect.getUpper(node.point, _disc);
if (point.getDouble(_disc) < node.point.getDouble(_disc)) {
Pair<Double, INDArray> left = nn(node.getLeft(), point, lower, dist, best, _discNext);
Pair<Double, INDArray> right = nn(node.getRight(), point, upper, dist, best, _discNext);
if (left.getKey() < dist)
return left;
else if (right.getKey() < dist)
return right;
} else {
Pair<Double, INDArray> left = nn(node.getRight(), point, upper, dist, best, _discNext);
Pair<Double, INDArray> right = nn(node.getLeft(), point, lower, dist, best, _discNext);
if (left.getKey() < dist)
return left;
else if (right.getKey() < dist)
return right;
}
return Pair.of(dist, best);
}
private KDNode delete(KDNode delete, int _disc) {
if (delete.getLeft() != null && delete.getRight() != null) {
if (delete.getParent() != null) {
if (delete.getParent().getLeft() == delete)
delete.getParent().setLeft(null);
else
delete.getParent().setRight(null);
}
return null;
}
int disc = _disc;
_disc = (_disc + 1) % dims;
Pair<KDNode, Integer> qd = null;
if (delete.getRight() != null) {
qd = min(delete.getRight(), disc, _disc);
} else if (delete.getLeft() != null)
qd = max(delete.getLeft(), disc, _disc);
if (qd == null) {// is leaf
return null;
}
delete.point = qd.getKey().point;
KDNode qFather = qd.getKey().getParent();
if (qFather.getLeft() == qd.getKey()) {
qFather.setLeft(delete(qd.getKey(), disc));
} else if (qFather.getRight() == qd.getKey()) {
qFather.setRight(delete(qd.getKey(), disc));
}
return delete;
}
private Pair<KDNode, Integer> max(KDNode node, int disc, int _disc) {
int discNext = (_disc + 1) % dims;
if (_disc == disc) {
KDNode child = node.getLeft();
if (child != null) {
return max(child, disc, discNext);
}
} else if (node.getLeft() != null || node.getRight() != null) {
Pair<KDNode, Integer> left = null, right = null;
if (node.getLeft() != null)
left = max(node.getLeft(), disc, discNext);
if (node.getRight() != null)
right = max(node.getRight(), disc, discNext);
if (left != null && right != null) {
double pointLeft = left.getKey().getPoint().getDouble(disc);
double pointRight = right.getKey().getPoint().getDouble(disc);
if (pointLeft > pointRight)
return left;
else
return right;
} else if (left != null)
return left;
else
return right;
}
return Pair.of(node, _disc);
}
private Pair<KDNode, Integer> min(KDNode node, int disc, int _disc) {
int discNext = (_disc + 1) % dims;
if (_disc == disc) {
KDNode child = node.getLeft();
if (child != null) {
return min(child, disc, discNext);
}
} else if (node.getLeft() != null || node.getRight() != null) {
Pair<KDNode, Integer> left = null, right = null;
if (node.getLeft() != null)
left = min(node.getLeft(), disc, discNext);
if (node.getRight() != null)
right = min(node.getRight(), disc, discNext);
if (left != null && right != null) {
double pointLeft = left.getKey().getPoint().getDouble(disc);
double pointRight = right.getKey().getPoint().getDouble(disc);
if (pointLeft < pointRight)
return left;
else
return right;
} else if (left != null)
return left;
else
return right;
}
return Pair.of(node, _disc);
}
/**
* The number of elements in the tree
* @return the number of elements in the tree
*/
public int size() {
return size;
}
private int successor(KDNode node, INDArray point, int disc) {
for (int i = disc; i < dims; i++) {
double pointI = point.getDouble(i);
double nodePointI = node.getPoint().getDouble(i);
if (pointI < nodePointI)
return LESS;
else if (pointI > nodePointI)
return GREATER;
}
throw new IllegalStateException("Point is equal!");
}
private static class KDNode {
private INDArray point;
private KDNode left, right, parent;
public KDNode(INDArray point) {
this.point = point;
}
public INDArray getPoint() {
return point;
}
public KDNode getLeft() {
return left;
}
public void setLeft(KDNode left) {
this.left = left;
}
public KDNode getRight() {
return right;
}
public void setRight(KDNode right) {
this.right = right;
}
public KDNode getParent() {
return parent;
}
public void setParent(KDNode parent) {
this.parent = parent;
}
}
}
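
A minimal usage sketch of the KDTree API defined above; the points, query and search radius are made up, and the printed results are expectations.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import java.util.List;

public class KDTreeSketch {
    public static void main(String[] args) {
        KDTree tree = new KDTree(2);                          // 2-dimensional points
        tree.insert(Nd4j.createFromArray(0.0f, 0.0f));
        tree.insert(Nd4j.createFromArray(1.0f, 1.0f));
        tree.insert(Nd4j.createFromArray(5.0f, 5.0f));

        // nearest neighbor of (0.9, 0.9)
        Pair<Double, INDArray> nearest = tree.nn(Nd4j.createFromArray(0.9f, 0.9f));
        System.out.println(nearest.getSecond());              // expected: [1.0, 1.0]

        // all neighbors within radius 2.0, closest first
        List<Pair<Float, INDArray>> near = tree.knn(Nd4j.createFromArray(0.9f, 0.9f), 2.0f);
        System.out.println(near.size());                      // expected: 2
    }
}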

View File

@ -1,109 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.kmeans;
import org.deeplearning4j.clustering.algorithm.BaseClusteringAlgorithm;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.strategy.ClusteringStrategy;
import org.deeplearning4j.clustering.strategy.FixedClusterCountStrategy;
public class KMeansClustering extends BaseClusteringAlgorithm {
private static final long serialVersionUID = 8476951388145944776L;
private static final double VARIATION_TOLERANCE= 1e-4;
/**
*
* @param clusteringStrategy the strategy that controls cluster count and termination
* @param useKMeansPlusPlus whether to use k-means++ initialization of the initial cluster centers
*/
protected KMeansClustering(ClusteringStrategy clusteringStrategy, boolean useKMeansPlusPlus) {
super(clusteringStrategy, useKMeansPlusPlus);
}
/**
* Setup a kmeans instance
* @param clusterCount the number of clusters
* @param maxIterationCount the max number of iterations to run kmeans
* @param distanceFunction the distance function to use for grouping
* @param inverse whether the distance function is a similarity and should be treated as inverted
* @param useKMeansPlusPlus whether to use k-means++ initialization of the initial cluster centers
* @return a configured KMeansClustering instance
*/
public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction,
boolean inverse, boolean useKMeansPlusPlus) {
ClusteringStrategy clusteringStrategy =
FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse);
clusteringStrategy.endWhenIterationCountEquals(maxIterationCount);
return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
}
/**
*
* @param clusterCount the number of clusters
* @param minDistributionVariationRate terminate once the cluster distribution variation rate drops below this value
* @param distanceFunction the distance function to use for grouping
* @param inverse whether the distance function is a similarity and should be treated as inverted
* @param allowEmptyClusters whether empty clusters are allowed (not used by this overload)
* @param useKMeansPlusPlus whether to use k-means++ initialization of the initial cluster centers
* @return a configured KMeansClustering instance
*/
public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction,
boolean inverse, boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse)
.endWhenDistributionVariationRateLessThan(minDistributionVariationRate);
return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
}
/**
* Setup a kmeans instance
* @param clusterCount the number of clusters
* @param maxIterationCount the max number of iterations to run kmeans
* @param distanceFunction the distance function to use for grouping
* @param useKMeansPlusPlus whether to use k-means++ initialization of the initial cluster centers
* @return a configured KMeansClustering instance
*/
public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction, boolean useKMeansPlusPlus) {
return setup(clusterCount, maxIterationCount, distanceFunction, false, useKMeansPlusPlus);
}
/**
*
* @param clusterCount the number of clusters
* @param minDistributionVariationRate terminate once the cluster distribution variation rate drops below this value
* @param distanceFunction the distance function to use for grouping
* @param allowEmptyClusters whether empty clusters are allowed (not used by this overload)
* @param useKMeansPlusPlus whether to use k-means++ initialization of the initial cluster centers
* @return a configured KMeansClustering instance
*/
public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction,
boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false);
clusteringStrategy.endWhenDistributionVariationRateLessThan(minDistributionVariationRate);
return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
}
public static KMeansClustering setup(int clusterCount, Distance distanceFunction,
boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false);
clusteringStrategy.endWhenDistributionVariationRateLessThan(VARIATION_TOLERANCE);
return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
}
}
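
For orientation, a short sketch of how the setup overloads above differ. Distance.EUCLIDEAN is assumed to be a member of the Distance enum imported by this class (it is not part of this diff), and the clustering itself would be run through the applyTo method inherited from BaseClusteringAlgorithm, also not shown here.

import org.deeplearning4j.clustering.algorithm.Distance;

public class KMeansSetupSketch {
    public static void main(String[] args) {
        // terminate after a fixed number of iterations
        KMeansClustering byIterations = KMeansClustering.setup(5, 100, Distance.EUCLIDEAN, false);

        // terminate when the cluster distribution variation rate drops below 1e-3
        KMeansClustering byVariation = KMeansClustering.setup(5, 1e-3, Distance.EUCLIDEAN, false, false);

        // terminate using the class's default variation tolerance (1e-4)
        KMeansClustering byDefaultTolerance = KMeansClustering.setup(5, Distance.EUCLIDEAN, false, false);
    }
}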

View File

@ -1,88 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.lsh;
import org.nd4j.linalg.api.ndarray.INDArray;
public interface LSH {
/**
* Returns the name of the distance measure associated with the LSH family of this implementation.
* Beware, hashing families and their amplification constructs are distance-specific.
*/
String getDistanceMeasure();
/**
* Returns the size of a hash compared against in one hashing bucket, corresponding to an AND construction
*
* denoting hashLength by h,
* amplifies a (d1, d2, p1, p2) hash family into a
* (d1, d2, p1^h, p2^h)-sensitive one (match probability is decreasing with h)
*
* @return the length of the hash in the AND construction used by this index
*/
int getHashLength();
/**
*
* denoting numTables by n,
* amplifies a (d1, d2, p1, p2) hash family into a
* (d1, d2, (1-p1^n), (1-p2^n))-sensitive one (match probability is increasing with n)
*
* @return the # of hash tables in the OR construction used by this index
*/
int getNumTables();
/**
* @return The dimension of the index vectors and queries
*/
int getInDimension();
/**
* Populates the index with data vectors.
* @param data the vectors to index
*/
void makeIndex(INDArray data);
/**
* Returns the set of all vectors that could approximately be considered neighbors of the query,
* without selection on the basis of distance or number of neighbors.
* @param query a vector to find neighbors for
* @return its approximate neighbors, unfiltered
*/
INDArray bucket(INDArray query);
/**
* Returns the approximate neighbors within a distance bound.
* @param query a vector to find neighbors for
* @param maxRange the maximum distance between results and the query
* @return approximate neighbors within the distance bounds
*/
INDArray search(INDArray query, double maxRange);
/**
* Returns the approximate neighbors within a k-closest bound
* @param query a vector to find neighbors for
* @param k the maximum number of closest neighbors to return
* @return at most k neighbors of the query, ordered by increasing distance
*/
INDArray search(INDArray query, int k);
}
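
The AND/OR amplification described in the javadoc above can be made concrete with a few lines of arithmetic; the collision probabilities below are illustrative, not taken from any particular hash family or dataset.

public class AmplificationSketch {
    public static void main(String[] args) {
        double p1 = 0.9, p2 = 0.6;   // single-hash collision probabilities for near / far pairs
        int h = 8;                   // hashLength: AND construction
        int n = 4;                   // numTables: OR construction

        double andP1 = Math.pow(p1, h);            // ~0.43  (near pairs still collide sometimes)
        double andP2 = Math.pow(p2, h);            // ~0.017 (far pairs almost never collide)

        double orP1 = 1 - Math.pow(1 - andP1, n);  // ~0.89  (recovered by using n tables)
        double orP2 = 1 - Math.pow(1 - andP2, n);  // ~0.066 (far pairs stay unlikely)

        System.out.printf("AND: p1^h=%.3f p2^h=%.3f%n", andP1, andP2);
        System.out.printf("OR : 1-(1-p1^h)^n=%.3f 1-(1-p2^h)^n=%.3f%n", orP1, orP2);
    }
}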

View File

@ -1,227 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.lsh;
import lombok.Getter;
import lombok.val;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.same.Sign;
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.ops.transforms.Transforms;
import java.util.Arrays;
public class RandomProjectionLSH implements LSH {
@Override
public String getDistanceMeasure(){
return "cosinedistance";
}
@Getter private int hashLength;
@Getter private int numTables;
@Getter private int inDimension;
@Getter private double radius;
INDArray randomProjection;
INDArray index;
INDArray indexData;
private INDArray gaussianRandomMatrix(int[] shape, Random rng){
INDArray res = Nd4j.create(shape);
GaussianDistribution op1 = new GaussianDistribution(res, 0.0, 1.0 / Math.sqrt(shape[0]));
Nd4j.getExecutioner().exec(op1, rng);
return res;
}
public RandomProjectionLSH(int hashLength, int numTables, int inDimension, double radius){
this(hashLength, numTables, inDimension, radius, Nd4j.getRandom());
}
/**
* Creates a locality-sensitive hashing index for the cosine distance,
* a (d1, d2, (180 - d1)/180, (180 - d2)/180)-sensitive hash family before amplification
*
* @param hashLength the length of the compared hash in an AND construction
* @param numTables the entropy-equivalent of a number of hash tables in an OR construction, implemented here with the
* multiple probes of Panigrahy (op. cit.)
* @param inDimension the dimensionality of the points being indexed
* @param radius the radius around the query within which entropy probes are generated, used instead of multiple
* physical hash tables in an OR construction
* @param rng a Random object to draw samples from
*/
public RandomProjectionLSH(int hashLength, int numTables, int inDimension, double radius, Random rng){
this.hashLength = hashLength;
this.numTables = numTables;
this.inDimension = inDimension;
this.radius = radius;
randomProjection = gaussianRandomMatrix(new int[]{inDimension, hashLength}, rng);
}
/**
* This picks uniformly distributed random points on the surface of a sphere of the configured radius,
* using the method of:
*
* An efficient method for generating uniformly distributed points on the surface of an n-dimensional sphere
* JS Hicks, RF Wheeling - Communications of the ACM, 1959
* @param data a query to generate multiple probes for
* @return a matrix of `numTables` probe vectors distributed around the query
*/
public INDArray entropy(INDArray data){
INDArray data2 =
Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.create(numTables, inDimension), radius));
INDArray norms = Nd4j.norm2(data2.dup(), -1);
Preconditions.checkState(norms.rank() == 1 && norms.size(0) == numTables, "Expected norm2 to have shape [%s], is %ndShape", norms.size(0), norms);
data2.diviColumnVector(norms);
data2.addiRowVector(data);
return data2;
}
/**
* Returns hash values for a particular query
* @param data a query vector
* @return its hashed value
*/
public INDArray hash(INDArray data) {
if (data.shape()[1] != inDimension){
throw new ND4JIllegalStateException(
String.format("Invalid shape: Requested INDArray shape %s, this table expects dimension %d",
Arrays.toString(data.shape()), inDimension));
}
INDArray projected = data.mmul(randomProjection);
INDArray res = Nd4j.getExecutioner().exec(new Sign(projected));
return res;
}
/**
* Populates the index. Beware, not incremental, any further call replaces the index instead of adding to it.
* @param data the vectors to index
*/
@Override
public void makeIndex(INDArray data) {
index = hash(data);
indexData = data;
}
// data elements in the same bucket as the query, without entropy
INDArray rawBucketOf(INDArray query){
INDArray pattern = hash(query);
INDArray res = Nd4j.zeros(DataType.BOOL, index.shape());
Nd4j.getExecutioner().exec(new BroadcastEqualTo(index, pattern, res, -1));
return res.castTo(Nd4j.defaultFloatingPointType()).min(-1);
}
@Override
public INDArray bucket(INDArray query) {
INDArray queryRes = rawBucketOf(query);
if(numTables > 1) {
INDArray entropyQueries = entropy(query);
// loop, addi + conditionalreplace -> poor man's OR function
for (int i = 0; i < numTables; i++) {
INDArray row = entropyQueries.getRow(i, true);
queryRes.addi(rawBucketOf(row));
}
BooleanIndexing.replaceWhere(queryRes, 1.0, Conditions.greaterThan(0.0));
}
return queryRes;
}
// data elements in the same entropy bucket as the query,
INDArray bucketData(INDArray query){
INDArray mask = bucket(query);
int nRes = mask.sum(0).getInt(0);
INDArray res = Nd4j.create(new int[] {nRes, inDimension});
int j = 0;
for (int i = 0; i < nRes; i++){
while (mask.getInt(j) == 0 && j < mask.length() - 1) {
j += 1;
}
if (mask.getInt(j) == 1) res.putRow(i, indexData.getRow(j));
j += 1;
}
return res;
}
@Override
public INDArray search(INDArray query, double maxRange) {
if (maxRange < 0)
throw new IllegalArgumentException("ANN search should have a positive maximum search radius");
INDArray bucketData = bucketData(query);
INDArray distances = Transforms.allCosineDistances(bucketData, query, -1);
INDArray[] idxs = Nd4j.sortWithIndices(distances, -1, true);
INDArray shuffleIndexes = idxs[0];
INDArray sortedDistances = idxs[1];
int accepted = 0;
while (accepted < sortedDistances.length() && sortedDistances.getDouble(accepted) <= maxRange) accepted += 1;
INDArray res = Nd4j.create(new int[] {accepted, inDimension});
for(int i = 0; i < accepted; i++){
res.putRow(i, bucketData.getRow(shuffleIndexes.getInt(i)));
}
return res;
}
@Override
public INDArray search(INDArray query, int k) {
if (k < 1)
throw new IllegalArgumentException("An ANN search for k neighbors should at least seek one neighbor");
INDArray bucketData = bucketData(query);
INDArray distances = Transforms.allCosineDistances(bucketData, query, -1);
INDArray[] idxs = Nd4j.sortWithIndices(distances, -1, true);
INDArray shuffleIndexes = idxs[0];
INDArray sortedDistances = idxs[1];
val accepted = Math.min(k, sortedDistances.shape()[1]);
INDArray res = Nd4j.create(accepted, inDimension);
for(int i = 0; i < accepted; i++){
res.putRow(i, bucketData.getRow(shuffleIndexes.getInt(i)));
}
return res;
}
}
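
A minimal usage sketch of the index above; the hash length, table count, dimension, radius and data are all made up.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RandomProjectionLSHSketch {
    public static void main(String[] args) {
        int inDimension = 32;
        RandomProjectionLSH lsh = new RandomProjectionLSH(16, 4, inDimension, 0.1);

        INDArray data = Nd4j.rand(1000, inDimension);     // 1000 random vectors to index
        lsh.makeIndex(data);

        INDArray query = Nd4j.rand(1, inDimension);       // queries are row vectors
        INDArray tenNearest = lsh.search(query, 10);      // at most 10 approximate neighbors
        INDArray withinRange = lsh.search(query, 0.5);    // approximate neighbors within cosine distance 0.5
        System.out.println(tenNearest.rows() + " / " + withinRange.rows());
    }
}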

View File

@ -1,38 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.optimisation;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import java.io.Serializable;
@Data
@NoArgsConstructor(access = AccessLevel.PROTECTED)
@AllArgsConstructor
public class ClusteringOptimization implements Serializable {
private ClusteringOptimizationType type;
private double value;
}

View File

@ -1,28 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.optimisation;
/**
* The quantities a clustering run can be asked to optimise.
*/
public enum ClusteringOptimizationType {
MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE,
MINIMIZE_MAXIMUM_POINT_TO_CENTER_DISTANCE,
MINIMIZE_AVERAGE_POINT_TO_POINT_DISTANCE,
MINIMIZE_MAXIMUM_POINT_TO_POINT_DISTANCE,
MINIMIZE_PER_CLUSTER_POINT_COUNT
}

View File

@ -1,115 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.quadtree;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.Serializable;
public class Cell implements Serializable {
private double x, y, hw, hh;
public Cell(double x, double y, double hw, double hh) {
this.x = x;
this.y = y;
this.hw = hw;
this.hh = hh;
}
/**
* Whether the given point is contained
* within this cell
* @param point the point to check
* @return true if the point is contained, false otherwise
*/
public boolean containsPoint(INDArray point) {
double first = point.getDouble(0), second = point.getDouble(1);
return x - hw <= first && x + hw >= first && y - hh <= second && y + hh >= second;
}
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (!(o instanceof Cell))
return false;
Cell cell = (Cell) o;
if (Double.compare(cell.hh, hh) != 0)
return false;
if (Double.compare(cell.hw, hw) != 0)
return false;
if (Double.compare(cell.x, x) != 0)
return false;
return Double.compare(cell.y, y) == 0;
}
@Override
public int hashCode() {
int result;
long temp;
temp = Double.doubleToLongBits(x);
result = (int) (temp ^ (temp >>> 32));
temp = Double.doubleToLongBits(y);
result = 31 * result + (int) (temp ^ (temp >>> 32));
temp = Double.doubleToLongBits(hw);
result = 31 * result + (int) (temp ^ (temp >>> 32));
temp = Double.doubleToLongBits(hh);
result = 31 * result + (int) (temp ^ (temp >>> 32));
return result;
}
public double getX() {
return x;
}
public void setX(double x) {
this.x = x;
}
public double getY() {
return y;
}
public void setY(double y) {
this.y = y;
}
public double getHw() {
return hw;
}
public void setHw(double hw) {
this.hw = hw;
}
public double getHh() {
return hh;
}
public void setHh(double hh) {
this.hh = hh;
}
}

View File

@ -1,383 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.quadtree;
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.Serializable;
import static java.lang.Math.max;
public class QuadTree implements Serializable {
private QuadTree parent, northWest, northEast, southWest, southEast;
private boolean isLeaf = true;
private int size, cumSize;
private Cell boundary;
static final int QT_NO_DIMS = 2;
static final int QT_NODE_CAPACITY = 1;
private INDArray buf = Nd4j.create(QT_NO_DIMS);
private INDArray data, centerOfMass = Nd4j.create(QT_NO_DIMS);
private int[] index = new int[QT_NODE_CAPACITY];
/**
* Pass in the matrix of points to index
* @param data an (n x 2) matrix, one 2-d point per row
*/
public QuadTree(INDArray data) {
INDArray meanY = data.mean(0);
INDArray minY = data.min(0);
INDArray maxY = data.max(0);
init(data, meanY.getDouble(0), meanY.getDouble(1),
max(maxY.getDouble(0) - meanY.getDouble(0), meanY.getDouble(0) - minY.getDouble(0))
+ Nd4j.EPS_THRESHOLD,
max(maxY.getDouble(1) - meanY.getDouble(1), meanY.getDouble(1) - minY.getDouble(1))
+ Nd4j.EPS_THRESHOLD);
fill();
}
public QuadTree(QuadTree parent, INDArray data, Cell boundary) {
this.parent = parent;
this.boundary = boundary;
this.data = data;
}
public QuadTree(Cell boundary) {
this.boundary = boundary;
}
private void init(INDArray data, double x, double y, double hw, double hh) {
boundary = new Cell(x, y, hw, hh);
this.data = data;
}
private void fill() {
for (int i = 0; i < data.rows(); i++)
insert(i);
}
/**
* Returns the child quadrant that contains the given coordinates
*
* @param coordinates the 2-d point to locate
* @return the child QuadTree whose cell contains the coordinates
*/
protected QuadTree findIndex(INDArray coordinates) {
// Compute the sector for the coordinates
boolean left = (coordinates.getDouble(0) <= (boundary.getX() + boundary.getHw() / 2));
boolean top = (coordinates.getDouble(1) <= (boundary.getY() + boundary.getHh() / 2));
// top left
QuadTree index = getNorthWest();
if (left) {
// left side
if (!top) {
// bottom left
index = getSouthWest();
}
} else {
// right side
if (top) {
// top right
index = getNorthEast();
} else {
// bottom right
index = getSouthEast();
}
}
return index;
}
/**
* Insert an index of the data in to the tree
* @param newIndex the index to insert in to the tree
* @return whether the index was inserted or not
*/
public boolean insert(int newIndex) {
// Ignore objects which do not belong in this quad tree
INDArray point = data.slice(newIndex);
if (!boundary.containsPoint(point))
return false;
cumSize++;
double mult1 = (double) (cumSize - 1) / (double) cumSize;
double mult2 = 1.0 / (double) cumSize;
centerOfMass.muli(mult1);
centerOfMass.addi(point.mul(mult2));
// If there is space in this quad tree and it is a leaf, add the object here
if (isLeaf() && size < QT_NODE_CAPACITY) {
index[size] = newIndex;
size++;
return true;
}
//duplicate point
if (size > 0) {
for (int i = 0; i < size; i++) {
INDArray compPoint = data.slice(index[i]);
if (point.getDouble(0) == compPoint.getDouble(0) && point.getDouble(1) == compPoint.getDouble(1))
return true;
}
}
// If this Node has already been subdivided just add the elements to the
// appropriate cell
if (!isLeaf()) {
QuadTree index = findIndex(point);
index.insert(newIndex);
return true;
}
if (isLeaf())
subDivide();
boolean ret = insertIntoOneOf(newIndex);
return ret;
}
private boolean insertIntoOneOf(int index) {
boolean success = false;
success = northWest.insert(index);
if (!success)
success = northEast.insert(index);
if (!success)
success = southWest.insert(index);
if (!success)
success = southEast.insert(index);
return success;
}
/**
* Returns whether the tree is consistent or not
* @return whether the tree is consistent or not
*/
public boolean isCorrect() {
for (int n = 0; n < size; n++) {
INDArray point = data.slice(index[n]);
if (!boundary.containsPoint(point))
return false;
}
return isLeaf() || northWest.isCorrect() && northEast.isCorrect() && southWest.isCorrect()
&& southEast.isCorrect();
}
/**
* Create four children
* which fully divide this cell
* into four quads of equal area
*/
public void subDivide() {
northWest = new QuadTree(this, data, new Cell(boundary.getX() - .5 * boundary.getHw(),
boundary.getY() - .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh()));
northEast = new QuadTree(this, data, new Cell(boundary.getX() + .5 * boundary.getHw(),
boundary.getY() - .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh()));
southWest = new QuadTree(this, data, new Cell(boundary.getX() - .5 * boundary.getHw(),
boundary.getY() + .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh()));
southEast = new QuadTree(this, data, new Cell(boundary.getX() + .5 * boundary.getHw(),
boundary.getY() + .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh()));
}
/**
* Compute non-edge (repulsive) forces using the Barnes-Hut approximation
* @param pointIndex the index of the point the forces are computed for
* @param theta the Barnes-Hut accuracy/speed trade-off parameter
* @param negativeForce accumulator for the negative (repulsive) force
* @param sumQ accumulator for the normalisation term
*/
public void computeNonEdgeForces(int pointIndex, double theta, INDArray negativeForce, AtomicDouble sumQ) {
// Make sure that we spend no time on empty nodes or self-interactions
if (cumSize == 0 || (isLeaf() && size == 1 && index[0] == pointIndex))
return;
// Compute distance between point and center-of-mass
buf.assign(data.slice(pointIndex)).subi(centerOfMass);
double D = Nd4j.getBlasWrapper().dot(buf, buf);
// Check whether we can use this node as a "summary"
if (isLeaf || FastMath.max(boundary.getHh(), boundary.getHw()) / FastMath.sqrt(D) < theta) {
// Compute and add t-SNE force between point and current node
double Q = 1.0 / (1.0 + D);
double mult = cumSize * Q;
sumQ.addAndGet(mult);
mult *= Q;
negativeForce.addi(buf.mul(mult));
} else {
// Recursively apply Barnes-Hut to children
northWest.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
northEast.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
southWest.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
southEast.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
}
}
/**
* Compute edge (attractive) forces for a sparse affinity matrix in CSR form
* @param rowP row pointers of the CSR matrix (a vector)
* @param colP column indices of the CSR matrix
* @param valP non-zero values (edge weights) of the CSR matrix
* @param N the number of points
* @param posF accumulator for the positive (attractive) forces, one row per point
*/
public void computeEdgeForces(INDArray rowP, INDArray colP, INDArray valP, int N, INDArray posF) {
if (!rowP.isVector())
throw new IllegalArgumentException("RowP must be a vector");
// Loop over all edges in the graph
double D;
for (int n = 0; n < N; n++) {
for (int i = rowP.getInt(n); i < rowP.getInt(n + 1); i++) {
// Compute pairwise distance and Q-value
buf.assign(data.slice(n)).subi(data.slice(colP.getInt(i)));
D = Nd4j.getBlasWrapper().dot(buf, buf);
D = valP.getDouble(i) / D;
// Sum positive force
posF.slice(n).addi(buf.mul(D));
}
}
}
/**
* The depth of the node
* @return the depth of the node
*/
public int depth() {
if (isLeaf())
return 1;
return 1 + max(max(northWest.depth(), northEast.depth()), max(southWest.depth(), southEast.depth()));
}
public INDArray getCenterOfMass() {
return centerOfMass;
}
public void setCenterOfMass(INDArray centerOfMass) {
this.centerOfMass = centerOfMass;
}
public QuadTree getParent() {
return parent;
}
public void setParent(QuadTree parent) {
this.parent = parent;
}
public QuadTree getNorthWest() {
return northWest;
}
public void setNorthWest(QuadTree northWest) {
this.northWest = northWest;
}
public QuadTree getNorthEast() {
return northEast;
}
public void setNorthEast(QuadTree northEast) {
this.northEast = northEast;
}
public QuadTree getSouthWest() {
return southWest;
}
public void setSouthWest(QuadTree southWest) {
this.southWest = southWest;
}
public QuadTree getSouthEast() {
return southEast;
}
public void setSouthEast(QuadTree southEast) {
this.southEast = southEast;
}
public boolean isLeaf() {
return isLeaf;
}
public void setLeaf(boolean isLeaf) {
this.isLeaf = isLeaf;
}
public int getSize() {
return size;
}
public void setSize(int size) {
this.size = size;
}
public int getCumSize() {
return cumSize;
}
public void setCumSize(int cumSize) {
this.cumSize = cumSize;
}
public Cell getBoundary() {
return boundary;
}
public void setBoundary(Cell boundary) {
this.boundary = boundary;
}
}
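
A minimal sketch of how the QuadTree above is driven during a Barnes-Hut gradient step; the point set and theta value are illustrative.

import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class QuadTreeSketch {
    public static void main(String[] args) {
        INDArray points = Nd4j.rand(100, 2);        // 100 two-dimensional points
        QuadTree tree = new QuadTree(points);        // builds and fills the tree

        INDArray negativeForce = Nd4j.create(2);
        AtomicDouble sumQ = new AtomicDouble(0.0);
        // repulsive force acting on point 0, with Barnes-Hut accuracy parameter theta = 0.5
        tree.computeNonEdgeForces(0, 0.5, negativeForce, sumQ);
        System.out.println(negativeForce + " " + sumQ.get());
    }
}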

View File

@ -1,104 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.randomprojection;
import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;
import java.util.ArrayList;
import java.util.List;
/**
* A forest of random projection trees for approximate nearest neighbor search.
*/
@Data
public class RPForest {
private int numTrees;
private List<RPTree> trees;
private INDArray data;
private int maxSize = 1000;
private String similarityFunction;
/**
* Create the rp forest with the specified number of trees
* @param numTrees the number of trees in the forest
* @param maxSize the max size of each tree
* @param similarityFunction the distance function to use
*/
public RPForest(int numTrees,int maxSize,String similarityFunction) {
this.numTrees = numTrees;
this.maxSize = maxSize;
this.similarityFunction = similarityFunction;
trees = new ArrayList<>(numTrees);
}
/**
* Build the trees from the given dataset
* @param x the input dataset (should be a 2d matrix)
*/
public void fit(INDArray x) {
this.data = x;
for(int i = 0; i < numTrees; i++) {
RPTree tree = new RPTree(data.columns(),maxSize,similarityFunction);
tree.buildTree(x);
trees.add(tree);
}
}
/**
* Get all candidate neighbors relative to a specific datapoint.
* @param input the query vector
* @return the union of the candidates returned by each tree in the forest
*/
public INDArray getAllCandidates(INDArray input) {
return RPUtils.getAllCandidates(input,trees,similarityFunction);
}
/**
* Query results up to length n
* nearest neighbors
* @param toQuery the query item
* @param n the number of nearest neighbors for the given data point
* @return the indices for the nearest neighbors
*/
public INDArray queryAll(INDArray toQuery,int n) {
return RPUtils.queryAll(toQuery,data,trees,n,similarityFunction);
}
/**
* Query all with the distances
* sorted by index
* @param query the query vector
* @param numResults the number of results to return
* @return a list of samples
*/
public List<Pair<Double, Integer>> queryWithDistances(INDArray query, int numResults) {
return RPUtils.queryAllWithDistances(query,this.data, trees,numResults,similarityFunction);
}
}
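
A minimal usage sketch of the forest above; the tree count, leaf size, data and query are made up.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RPForestSketch {
    public static void main(String[] args) {
        INDArray data = Nd4j.rand(500, 64);                   // 500 vectors of dimension 64
        RPForest forest = new RPForest(10, 50, "euclidean");  // 10 trees, at most 50 points per leaf
        forest.fit(data);

        INDArray query = Nd4j.rand(1, 64);
        INDArray nearestIndices = forest.queryAll(query, 5);  // indices of the 5 approximate nearest neighbors
        System.out.println(nearestIndices);
    }
}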

View File

@ -1,57 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.randomprojection;
import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
@Data
public class RPHyperPlanes {
private int dim;
private INDArray wholeHyperPlane;
public RPHyperPlanes(int dim) {
this.dim = dim;
}
public INDArray getHyperPlaneAt(int depth) {
if(wholeHyperPlane.isVector())
return wholeHyperPlane;
return wholeHyperPlane.slice(depth);
}
/**
* Add a new random element to the hyper plane.
*/
public void addRandomHyperPlane() {
INDArray newPlane = Nd4j.randn(new int[] {1,dim});
newPlane.divi(newPlane.normmaxNumber());
if(wholeHyperPlane == null)
wholeHyperPlane = newPlane;
else {
wholeHyperPlane = Nd4j.concat(0,wholeHyperPlane,newPlane);
}
}
}
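
A small sketch of how the hyperplane set grows as a random projection tree gets deeper; the dimension is illustrative.

import org.nd4j.linalg.factory.Nd4j;
import java.util.Arrays;

public class RPHyperPlanesSketch {
    public static void main(String[] args) {
        RPHyperPlanes planes = new RPHyperPlanes(16);
        planes.addRandomHyperPlane();                         // first plane: a 1 x 16 vector
        planes.addRandomHyperPlane();                         // second plane stacked below: 2 x 16
        System.out.println(planes.getHyperPlaneAt(0));        // plane used at depth 0
        System.out.println(Arrays.toString(planes.getWholeHyperPlane().shape()));
    }
}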

View File

@ -1,48 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.randomprojection;
import lombok.Data;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Future;
@Data
public class RPNode {
private int depth;
private RPNode left,right;
private Future<RPNode> leftFuture,rightFuture;
private List<Integer> indices;
private double median;
private RPTree tree;
public RPNode(RPTree tree,int depth) {
this.depth = depth;
this.tree = tree;
indices = new ArrayList<>();
}
}

View File

@ -1,130 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.randomprojection;
import lombok.Builder;
import lombok.Data;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;
@Data
public class RPTree {
private RPNode root;
private RPHyperPlanes rpHyperPlanes;
private int dim;
    //also known as leaf size
private int maxSize;
private INDArray X;
private String similarityFunction = "euclidean";
private WorkspaceConfiguration workspaceConfiguration;
private ExecutorService searchExecutor;
private int searchWorkers;
/**
*
* @param dim the dimension of the vectors
* @param maxSize the max size of the leaves
*
*/
@Builder
public RPTree(int dim, int maxSize,String similarityFunction) {
this.dim = dim;
this.maxSize = maxSize;
rpHyperPlanes = new RPHyperPlanes(dim);
root = new RPNode(this,0);
this.similarityFunction = similarityFunction;
workspaceConfiguration = WorkspaceConfiguration.builder().cyclesBeforeInitialization(1)
.policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.FIRST_LOOP)
.policyMirroring(MirroringPolicy.FULL).policyReset(ResetPolicy.BLOCK_LEFT)
.policySpill(SpillPolicy.REALLOCATE).build();
}
/**
*
* @param dim the dimension of the vectors
* @param maxSize the max size of the leaves
*
*/
public RPTree(int dim, int maxSize) {
this(dim,maxSize,"euclidean");
}
/**
* Build the tree with the given input data
* @param x
*/
public void buildTree(INDArray x) {
this.X = x;
for(int i = 0; i < x.rows(); i++) {
root.getIndices().add(i);
}
RPUtils.buildTree(this,root,rpHyperPlanes,
x,maxSize,0,similarityFunction);
}
public void addNodeAtIndex(int idx,INDArray toAdd) {
RPNode query = RPUtils.query(root,rpHyperPlanes,toAdd,similarityFunction);
query.getIndices().add(idx);
}
public List<RPNode> getLeaves() {
List<RPNode> nodes = new ArrayList<>();
RPUtils.scanForLeaves(nodes,getRoot());
return nodes;
}
/**
* Query all with the distances
* sorted by index
* @param query the query vector
* @param numResults the number of results to return
* @return a list of samples
*/
public List<Pair<Double, Integer>> queryWithDistances(INDArray query, int numResults) {
return RPUtils.queryAllWithDistances(query,X,Arrays.asList(this),numResults,similarityFunction);
}
public INDArray query(INDArray query,int numResults) {
return RPUtils.queryAll(query,X,Arrays.asList(this),numResults,similarityFunction);
}
public List<Integer> getCandidates(INDArray target) {
return RPUtils.getCandidates(target,Arrays.asList(this),similarityFunction);
}
}
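
A minimal usage sketch of the removed RPTree API, built on random data with the default euclidean similarity (hypothetical, not part of the deleted sources; imports elided):

    INDArray data = Nd4j.rand(1000, 64);                 // 1000 vectors of dimension 64
    RPTree tree = new RPTree(64, 50);                     // leaf size 50, euclidean similarity
    tree.buildTree(data);
    INDArray queryVec = data.getRow(0);
    INDArray nearest = tree.query(queryVec, 10);          // indices of up to 10 nearest candidates
    List<Pair<Double, Integer>> withDistances = tree.queryWithDistances(queryVec, 10);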


@ -1,481 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.randomprojection;
import org.nd4j.shade.guava.primitives.Doubles;
import lombok.val;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.reduce3.*;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import java.util.*;
public class RPUtils {
private static ThreadLocal<Map<String,DifferentialFunction>> functionInstances = new ThreadLocal<>();
public static <T extends DifferentialFunction> DifferentialFunction getOp(String name,
INDArray x,
INDArray y,
INDArray result) {
Map<String,DifferentialFunction> ops = functionInstances.get();
if(ops == null) {
ops = new HashMap<>();
functionInstances.set(ops);
}
boolean allDistances = x.length() != y.length();
switch(name) {
case "cosinedistance":
if(!ops.containsKey(name) || ((CosineDistance)ops.get(name)).isComplexAccumulation() != allDistances) {
CosineDistance cosineDistance = new CosineDistance(x,y,result,allDistances);
ops.put(name,cosineDistance);
return cosineDistance;
}
            else {
                CosineDistance cosineDistance = (CosineDistance) ops.get(name);
                // update the cached op with the current operands, as the other branches do
                cosineDistance.setX(x);
                cosineDistance.setY(y);
                cosineDistance.setZ(result);
                return cosineDistance;
            }
case "cosinesimilarity":
if(!ops.containsKey(name) || ((CosineSimilarity)ops.get(name)).isComplexAccumulation() != allDistances) {
CosineSimilarity cosineSimilarity = new CosineSimilarity(x,y,result,allDistances);
ops.put(name,cosineSimilarity);
return cosineSimilarity;
}
else {
CosineSimilarity cosineSimilarity = (CosineSimilarity) ops.get(name);
cosineSimilarity.setX(x);
cosineSimilarity.setY(y);
cosineSimilarity.setZ(result);
return cosineSimilarity;
}
case "manhattan":
if(!ops.containsKey(name) || ((ManhattanDistance)ops.get(name)).isComplexAccumulation() != allDistances) {
ManhattanDistance manhattanDistance = new ManhattanDistance(x,y,result,allDistances);
ops.put(name,manhattanDistance);
return manhattanDistance;
}
else {
ManhattanDistance manhattanDistance = (ManhattanDistance) ops.get(name);
manhattanDistance.setX(x);
manhattanDistance.setY(y);
manhattanDistance.setZ(result);
return manhattanDistance;
}
case "jaccard":
if(!ops.containsKey(name) || ((JaccardDistance)ops.get(name)).isComplexAccumulation() != allDistances) {
JaccardDistance jaccardDistance = new JaccardDistance(x,y,result,allDistances);
ops.put(name,jaccardDistance);
return jaccardDistance;
}
else {
JaccardDistance jaccardDistance = (JaccardDistance) ops.get(name);
jaccardDistance.setX(x);
jaccardDistance.setY(y);
jaccardDistance.setZ(result);
return jaccardDistance;
}
case "hamming":
if(!ops.containsKey(name) || ((HammingDistance)ops.get(name)).isComplexAccumulation() != allDistances) {
HammingDistance hammingDistance = new HammingDistance(x,y,result,allDistances);
ops.put(name,hammingDistance);
return hammingDistance;
}
else {
HammingDistance hammingDistance = (HammingDistance) ops.get(name);
hammingDistance.setX(x);
hammingDistance.setY(y);
hammingDistance.setZ(result);
return hammingDistance;
}
//euclidean
default:
if(!ops.containsKey(name) || ((EuclideanDistance)ops.get(name)).isComplexAccumulation() != allDistances) {
EuclideanDistance euclideanDistance = new EuclideanDistance(x,y,result,allDistances);
ops.put(name,euclideanDistance);
return euclideanDistance;
}
else {
EuclideanDistance euclideanDistance = (EuclideanDistance) ops.get(name);
euclideanDistance.setX(x);
euclideanDistance.setY(y);
euclideanDistance.setZ(result);
return euclideanDistance;
}
}
}
/**
* Query all trees using the given input and data
* @param toQuery the query vector
* @param X the input data to query
* @param trees the trees to query
* @param n the number of results to search for
* @param similarityFunction the similarity function to use
     * @return the sorted (distance, index) pairs for the nearest results
*/
public static List<Pair<Double,Integer>> queryAllWithDistances(INDArray toQuery,INDArray X,List<RPTree> trees,int n,String similarityFunction) {
if(trees.isEmpty()) {
throw new ND4JIllegalArgumentException("Trees is empty!");
}
List<Integer> candidates = getCandidates(toQuery, trees,similarityFunction);
val sortedCandidates = sortCandidates(toQuery,X,candidates,similarityFunction);
int numReturns = Math.min(n,sortedCandidates.size());
List<Pair<Double,Integer>> ret = new ArrayList<>(numReturns);
for(int i = 0; i < numReturns; i++) {
ret.add(sortedCandidates.get(i));
}
return ret;
}
/**
* Query all trees using the given input and data
* @param toQuery the query vector
* @param X the input data to query
* @param trees the trees to query
* @param n the number of results to search for
* @param similarityFunction the similarity function to use
* @return the indices (in order) in the ndarray
*/
public static INDArray queryAll(INDArray toQuery,INDArray X,List<RPTree> trees,int n,String similarityFunction) {
if(trees.isEmpty()) {
throw new ND4JIllegalArgumentException("Trees is empty!");
}
List<Integer> candidates = getCandidates(toQuery, trees,similarityFunction);
val sortedCandidates = sortCandidates(toQuery,X,candidates,similarityFunction);
int numReturns = Math.min(n,sortedCandidates.size());
INDArray result = Nd4j.create(numReturns);
for(int i = 0; i < numReturns; i++) {
result.putScalar(i,sortedCandidates.get(i).getSecond());
}
return result;
}
/**
* Get the sorted distances given the
* query vector, input data, given the list of possible search candidates
* @param x the query vector
* @param X the input data to use
* @param candidates the possible search candidates
* @param similarityFunction the similarity function to use
* @return the sorted distances
*/
public static List<Pair<Double,Integer>> sortCandidates(INDArray x,INDArray X,
List<Integer> candidates,
String similarityFunction) {
int prevIdx = -1;
List<Pair<Double,Integer>> ret = new ArrayList<>();
for(int i = 0; i < candidates.size(); i++) {
if(candidates.get(i) != prevIdx) {
ret.add(Pair.of(computeDistance(similarityFunction,X.slice(candidates.get(i)),x),candidates.get(i)));
}
            prevIdx = candidates.get(i); // track the previous candidate index, not the loop counter
}
Collections.sort(ret, new Comparator<Pair<Double, Integer>>() {
@Override
public int compare(Pair<Double, Integer> doubleIntegerPair, Pair<Double, Integer> t1) {
return Doubles.compare(doubleIntegerPair.getFirst(),t1.getFirst());
}
});
return ret;
}
/**
* Get the search candidates as indices given the input
* and similarity function
* @param x the input data to search with
* @param trees the trees to search
* @param similarityFunction the function to use for similarity
* @return the list of indices as the search results
*/
public static INDArray getAllCandidates(INDArray x,List<RPTree> trees,String similarityFunction) {
List<Integer> candidates = getCandidates(x,trees,similarityFunction);
Collections.sort(candidates);
int prevIdx = -1;
int idxCount = 0;
List<Pair<Integer,Integer>> scores = new ArrayList<>();
for(int i = 0; i < candidates.size(); i++) {
if(candidates.get(i) == prevIdx) {
idxCount++;
}
else if(prevIdx != -1) {
scores.add(Pair.of(idxCount,prevIdx));
idxCount = 1;
}
            prevIdx = candidates.get(i); // track the previous candidate index, not the loop counter
}
scores.add(Pair.of(idxCount,prevIdx));
INDArray arr = Nd4j.create(scores.size());
for(int i = 0; i < scores.size(); i++) {
arr.putScalar(i,scores.get(i).getSecond());
}
return arr;
}
/**
* Get the search candidates as indices given the input
* and similarity function
* @param x the input data to search with
* @param roots the trees to search
* @param similarityFunction the function to use for similarity
* @return the list of indices as the search results
*/
public static List<Integer> getCandidates(INDArray x,List<RPTree> roots,String similarityFunction) {
Set<Integer> ret = new LinkedHashSet<>();
for(RPTree tree : roots) {
RPNode root = tree.getRoot();
RPNode query = query(root,tree.getRpHyperPlanes(),x,similarityFunction);
ret.addAll(query.getIndices());
}
return new ArrayList<>(ret);
}
/**
* Query the tree starting from the given node
* using the given hyper plane and similarity function
* @param from the node to start from
* @param planes the hyper plane to query
* @param x the input data
* @param similarityFunction the similarity function to use
* @return the leaf node representing the given query from a
* search in the tree
*/
public static RPNode query(RPNode from,RPHyperPlanes planes,INDArray x,String similarityFunction) {
if(from.getLeft() == null && from.getRight() == null) {
return from;
}
INDArray hyperPlane = planes.getHyperPlaneAt(from.getDepth());
double dist = computeDistance(similarityFunction,x,hyperPlane);
if(dist <= from.getMedian()) {
return query(from.getLeft(),planes,x,similarityFunction);
}
else {
return query(from.getRight(),planes,x,similarityFunction);
}
}
/**
* Compute the distance between 2 vectors
* given a function name. Valid function names:
* euclidean: euclidean distance
* cosinedistance: cosine distance
     * cosinesimilarity: cosine similarity
* manhattan: manhattan distance
* jaccard: jaccard distance
* hamming: hamming distance
* @param function the function to use (default euclidean distance)
* @param x the first vector
* @param y the second vector
* @return the distance between the 2 vectors given the inputs
*/
public static INDArray computeDistanceMulti(String function,INDArray x,INDArray y,INDArray result) {
ReduceOp op = (ReduceOp) getOp(function, x, y, result);
op.setDimensions(1);
Nd4j.getExecutioner().exec(op);
return op.z();
}
    /**
* Compute the distance between 2 vectors
* given a function name. Valid function names:
* euclidean: euclidean distance
* cosinedistance: cosine distance
     * cosinesimilarity: cosine similarity
* manhattan: manhattan distance
* jaccard: jaccard distance
* hamming: hamming distance
* @param function the function to use (default euclidean distance)
* @param x the first vector
* @param y the second vector
* @return the distance between the 2 vectors given the inputs
*/
public static double computeDistance(String function,INDArray x,INDArray y,INDArray result) {
ReduceOp op = (ReduceOp) getOp(function, x, y, result);
Nd4j.getExecutioner().exec(op);
return op.z().getDouble(0);
}
/**
* Compute the distance between 2 vectors
* given a function name. Valid function names:
* euclidean: euclidean distance
* cosinedistance: cosine distance
     * cosinesimilarity: cosine similarity
* manhattan: manhattan distance
* jaccard: jaccard distance
* hamming: hamming distance
* @param function the function to use (default euclidean distance)
* @param x the first vector
* @param y the second vector
* @return the distance between the 2 vectors given the inputs
*/
public static double computeDistance(String function,INDArray x,INDArray y) {
return computeDistance(function,x,y,Nd4j.scalar(0.0));
}
/**
* Initialize the tree given the input parameters
* @param tree the tree to initialize
* @param from the starting node
* @param planes the hyper planes to use (vector space for similarity)
* @param X the input data
* @param maxSize the max number of indices on a given leaf node
* @param depth the current depth of the tree
* @param similarityFunction the similarity function to use
*/
public static void buildTree(RPTree tree,
RPNode from,
RPHyperPlanes planes,
INDArray X,
int maxSize,
int depth,
String similarityFunction) {
if(from.getIndices().size() <= maxSize) {
//slimNode
slimNode(from);
return;
}
List<Double> distances = new ArrayList<>();
RPNode left = new RPNode(tree,depth + 1);
RPNode right = new RPNode(tree,depth + 1);
if(planes.getWholeHyperPlane() == null || depth >= planes.getWholeHyperPlane().rows()) {
planes.addRandomHyperPlane();
}
INDArray hyperPlane = planes.getHyperPlaneAt(depth);
for(int i = 0; i < from.getIndices().size(); i++) {
double cosineSim = computeDistance(similarityFunction,hyperPlane,X.slice(from.getIndices().get(i)));
distances.add(cosineSim);
}
Collections.sort(distances);
from.setMedian(distances.get(distances.size() / 2));
for(int i = 0; i < from.getIndices().size(); i++) {
double cosineSim = computeDistance(similarityFunction,hyperPlane,X.slice(from.getIndices().get(i)));
if(cosineSim <= from.getMedian()) {
left.getIndices().add(from.getIndices().get(i));
}
else {
right.getIndices().add(from.getIndices().get(i));
}
}
//failed split
if(left.getIndices().isEmpty() || right.getIndices().isEmpty()) {
slimNode(from);
return;
}
from.setLeft(left);
from.setRight(right);
slimNode(from);
buildTree(tree,left,planes,X,maxSize,depth + 1,similarityFunction);
buildTree(tree,right,planes,X,maxSize,depth + 1,similarityFunction);
}
/**
* Scan for leaves accumulating
* the nodes in the passed in list
* @param nodes the nodes so far
* @param scan the tree to scan
*/
public static void scanForLeaves(List<RPNode> nodes,RPTree scan) {
scanForLeaves(nodes,scan.getRoot());
}
/**
* Scan for leaves accumulating
* the nodes in the passed in list
* @param nodes the nodes so far
*/
public static void scanForLeaves(List<RPNode> nodes,RPNode current) {
if(current.getLeft() == null && current.getRight() == null)
nodes.add(current);
if(current.getLeft() != null)
scanForLeaves(nodes,current.getLeft());
if(current.getRight() != null)
scanForLeaves(nodes,current.getRight());
}
/**
* Prune indices from the given node
* when it's a leaf
* @param node the node to prune
*/
public static void slimNode(RPNode node) {
if(node.getRight() != null && node.getLeft() != null) {
node.getIndices().clear();
}
}
}
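
The distance helpers above are keyed by the function names listed in the javadoc; a small sketch (hypothetical, not part of the deleted sources; imports elided):

    INDArray a = Nd4j.create(new double[] {1, 2, 3});
    INDArray b = Nd4j.create(new double[] {3, 2, 1});
    double euclidean = RPUtils.computeDistance("euclidean", a, b);
    double cosine    = RPUtils.computeDistance("cosinedistance", a, b);
    // unrecognised names fall through to euclidean distance
    double fallback  = RPUtils.computeDistance("not-a-real-name", a, b);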


@ -1,87 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.sptree;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.Serializable;
/**
* @author Adam Gibson
*/
public class Cell implements Serializable {
private int dimension;
private INDArray corner, width;
public Cell(int dimension) {
this.dimension = dimension;
}
public double corner(int d) {
return corner.getDouble(d);
}
public double width(int d) {
return width.getDouble(d);
}
public void setCorner(int d, double corner) {
this.corner.putScalar(d, corner);
}
public void setWidth(int d, double width) {
this.width.putScalar(d, width);
}
public void setWidth(INDArray width) {
this.width = width;
}
public void setCorner(INDArray corner) {
this.corner = corner;
}
public boolean contains(INDArray point) {
INDArray cornerMinusWidth = corner.sub(width);
INDArray cornerPlusWidth = corner.add(width);
for (int d = 0; d < dimension; d++) {
double pointD = point.getDouble(d);
if (cornerMinusWidth.getDouble(d) > pointD)
return false;
if (cornerPlusWidth.getDouble(d) < pointD)
return false;
}
return true;
}
public INDArray width() {
return width;
}
public INDArray corner() {
return corner;
}
}
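
A short sketch of the bounds check implemented by Cell.contains, where corner acts as the cell centre and width as the half-extent per dimension (hypothetical, not part of the deleted sources; imports elided):

    Cell cell = new Cell(2);
    cell.setCorner(Nd4j.create(new double[] {0.0, 0.0}));   // cell centre
    cell.setWidth(Nd4j.create(new double[] {1.0, 1.0}));    // half-extent per dimension
    boolean inside  = cell.contains(Nd4j.create(new double[] {0.5, -0.5}));  // true
    boolean outside = cell.contains(Nd4j.create(new double[] {2.0, 0.0}));   // false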


@ -1,95 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.sptree;
import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.factory.Nd4j;
import java.io.Serializable;
@Data
public class DataPoint implements Serializable {
private int index;
private INDArray point;
private long d;
private String functionName;
private boolean invert = false;
public DataPoint(int index, INDArray point, boolean invert) {
this(index, point, "euclidean");
this.invert = invert;
}
public DataPoint(int index, INDArray point, String functionName, boolean invert) {
this.index = index;
this.point = point;
this.functionName = functionName;
this.d = point.length();
this.invert = invert;
}
public DataPoint(int index, INDArray point) {
this(index, point, false);
}
public DataPoint(int index, INDArray point, String functionName) {
this(index, point, functionName, false);
}
/**
* Euclidean distance
* @param point the distance from this point to the given point
* @return the distance between the two points
*/
public float distance(DataPoint point) {
switch (functionName) {
case "euclidean":
float ret = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this.point, point.point))
.getFinalResult().floatValue();
return invert ? -ret : ret;
case "cosinesimilarity":
float ret2 = Nd4j.getExecutioner().execAndReturn(new CosineSimilarity(this.point, point.point))
.getFinalResult().floatValue();
return invert ? -ret2 : ret2;
case "manhattan":
float ret3 = Nd4j.getExecutioner().execAndReturn(new ManhattanDistance(this.point, point.point))
.getFinalResult().floatValue();
return invert ? -ret3 : ret3;
case "dot":
float dotRet = (float) Nd4j.getBlasWrapper().dot(this.point, point.point);
return invert ? -dotRet : dotRet;
default:
float ret4 = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this.point, point.point))
.getFinalResult().floatValue();
return invert ? -ret4 : ret4;
}
}
}
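
DataPoint.distance dispatches on the configured function name; a brief sketch with illustrative values (hypothetical, not part of the deleted sources; imports elided):

    DataPoint p1 = new DataPoint(0, Nd4j.create(new double[] {0, 0}));             // euclidean by default
    DataPoint p2 = new DataPoint(1, Nd4j.create(new double[] {3, 4}), "euclidean");
    float d = p1.distance(p2);   // 5.0 for euclidean distance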


@ -1,83 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.sptree;
import java.io.Serializable;
/**
* @author Adam Gibson
*/
public class HeapItem implements Serializable, Comparable<HeapItem> {
private int index;
private double distance;
public HeapItem(int index, double distance) {
this.index = index;
this.distance = distance;
}
public int getIndex() {
return index;
}
public void setIndex(int index) {
this.index = index;
}
public double getDistance() {
return distance;
}
public void setDistance(double distance) {
this.distance = distance;
}
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
HeapItem heapItem = (HeapItem) o;
if (index != heapItem.index)
return false;
return Double.compare(heapItem.distance, distance) == 0;
}
@Override
public int hashCode() {
int result;
long temp;
result = index;
temp = Double.doubleToLongBits(distance);
result = 31 * result + (int) (temp ^ (temp >>> 32));
return result;
}
@Override
public int compareTo(HeapItem o) {
        // preserves the original descending-by-distance ordering while satisfying the Comparable contract
        return Double.compare(o.distance, distance);
}
}


@ -1,72 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.sptree;
import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;
import java.io.Serializable;
@Data
public class HeapObject implements Serializable, Comparable<HeapObject> {
private int index;
private INDArray point;
private double distance;
public HeapObject(int index, INDArray point, double distance) {
this.index = index;
this.point = point;
this.distance = distance;
}
@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
HeapObject heapObject = (HeapObject) o;
if (!point.equals(heapObject.point))
return false;
return Double.compare(heapObject.distance, distance) == 0;
}
@Override
public int hashCode() {
int result;
long temp;
result = index;
temp = Double.doubleToLongBits(distance);
result = 31 * result + (int) (temp ^ (temp >>> 32));
return result;
}
@Override
public int compareTo(HeapObject o) {
        // preserves the original descending-by-distance ordering while satisfying the Comparable contract
        return Double.compare(o.distance, distance);
}
}


@ -1,425 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.sptree;
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import lombok.val;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Set;
/**
* @author Adam Gibson
*/
public class SpTree implements Serializable {
public final static String workspaceExternal = "SPTREE_LOOP_EXTERNAL";
private int D;
private INDArray data;
public final static int NODE_RATIO = 8000;
private int N;
private int size;
private int cumSize;
private Cell boundary;
private INDArray centerOfMass;
private SpTree parent;
private int[] index;
private int nodeCapacity;
private int numChildren = 2;
private boolean isLeaf = true;
private Collection<INDArray> indices;
private SpTree[] children;
private static Logger log = LoggerFactory.getLogger(SpTree.class);
private String similarityFunction = Distance.EUCLIDEAN.toString();
public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices,
String similarityFunction) {
init(parent, data, corner, width, indices, similarityFunction);
}
public SpTree(INDArray data, Collection<INDArray> indices, String similarityFunction) {
this.indices = indices;
this.N = data.rows();
this.D = data.columns();
this.similarityFunction = similarityFunction;
data = data.dup();
INDArray meanY = data.mean(0);
INDArray minY = data.min(0);
INDArray maxY = data.max(0);
INDArray width = Nd4j.create(data.dataType(), meanY.shape());
for (int i = 0; i < width.length(); i++) {
width.putScalar(i, Math.max(maxY.getDouble(i) - meanY.getDouble(i),
meanY.getDouble(i) - minY.getDouble(i)) + Nd4j.EPS_THRESHOLD);
}
try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
init(null, data, meanY, width, indices, similarityFunction);
fill(N);
}
}
public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices) {
this(parent, data, corner, width, indices, "euclidean");
}
public SpTree(INDArray data, Collection<INDArray> indices) {
this(data, indices, "euclidean");
}
public SpTree(INDArray data) {
this(data, new ArrayList<INDArray>());
}
public MemoryWorkspace workspace() {
return null;
}
private void init(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices,
String similarityFunction) {
this.parent = parent;
D = data.columns();
N = data.rows();
this.similarityFunction = similarityFunction;
nodeCapacity = N % NODE_RATIO;
index = new int[nodeCapacity];
for (int d = 1; d < this.D; d++)
numChildren *= 2;
this.indices = indices;
isLeaf = true;
size = 0;
cumSize = 0;
children = new SpTree[numChildren];
this.data = data;
boundary = new Cell(D);
boundary.setCorner(corner.dup());
boundary.setWidth(width.dup());
centerOfMass = Nd4j.create(data.dataType(), D);
}
private boolean insert(int index) {
/*MemoryWorkspace workspace =
workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace()
: Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
workspaceConfigurationExternal,
workspaceExternal);
try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ {
INDArray point = data.slice(index);
/*boolean contains = false;
SpTreeCell op = new SpTreeCell(boundary.corner(), boundary.width(), point, N, contains);
Nd4j.getExecutioner().exec(op);
op.getOutputArgument(0).getScalar(0);
if (!contains) return false;*/
if (!boundary.contains(point))
return false;
cumSize++;
double mult1 = (double) (cumSize - 1) / (double) cumSize;
double mult2 = 1.0 / (double) cumSize;
centerOfMass.muli(mult1);
centerOfMass.addi(point.mul(mult2));
// If there is space in this quad tree and it is a leaf, add the object here
if (isLeaf() && size < nodeCapacity) {
this.index[size] = index;
indices.add(point);
size++;
return true;
}
for (int i = 0; i < size; i++) {
INDArray compPoint = data.slice(this.index[i]);
if (compPoint.equals(point))
return true;
}
if (isLeaf())
subDivide();
// Find out where the point can be inserted
for (int i = 0; i < numChildren; i++) {
if (children[i].insert(index))
return true;
}
throw new IllegalStateException("Shouldn't reach this state");
}
}
/**
     * Subdivide the node into its
     * numChildren (2^D) children
*/
public void subDivide() {
/*MemoryWorkspace workspace =
workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace()
: Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
workspaceConfigurationExternal,
workspaceExternal);
try (MemoryWorkspace ws = workspace.notifyScopeEntered()) */{
INDArray newCorner = Nd4j.create(data.dataType(), D);
INDArray newWidth = Nd4j.create(data.dataType(), D);
for (int i = 0; i < numChildren; i++) {
int div = 1;
for (int d = 0; d < D; d++) {
newWidth.putScalar(d, .5 * boundary.width(d));
if ((i / div) % 2 == 1)
newCorner.putScalar(d, boundary.corner(d) - .5 * boundary.width(d));
else
newCorner.putScalar(d, boundary.corner(d) + .5 * boundary.width(d));
div *= 2;
}
children[i] = new SpTree(this, data, newCorner, newWidth, indices);
}
// Move existing points to correct children
for (int i = 0; i < size; i++) {
boolean success = false;
for (int j = 0; j < this.numChildren; j++)
if (!success)
success = children[j].insert(index[i]);
index[i] = -1;
}
// Empty parent node
size = 0;
isLeaf = false;
}
}
/**
     * Compute non-edge forces using Barnes-Hut
* @param pointIndex
* @param theta
* @param negativeForce
* @param sumQ
*/
public void computeNonEdgeForces(int pointIndex, double theta, INDArray negativeForce, AtomicDouble sumQ) {
// Make sure that we spend no time on empty nodes or self-interactions
INDArray buf = Nd4j.create(data.dataType(), this.D);
if (cumSize == 0 || (isLeaf() && size == 1 && index[0] == pointIndex))
return;
/* MemoryWorkspace workspace =
workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace()
: Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
workspaceConfigurationExternal,
workspaceExternal);
try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ {
// Compute distance between point and center-of-mass
data.slice(pointIndex).subi(centerOfMass, buf);
double D = Nd4j.getBlasWrapper().dot(buf, buf);
double maxWidth = boundary.width().maxNumber().doubleValue();
// Check whether we can use this node as a "summary"
if (isLeaf() || maxWidth / Math.sqrt(D) < theta) {
// Compute and add t-SNE force between point and current node
double Q = 1.0 / (1.0 + D);
double mult = cumSize * Q;
sumQ.addAndGet(mult);
mult *= Q;
negativeForce.addi(buf.mul(mult));
} else {
// Recursively apply Barnes-Hut to children
for (int i = 0; i < numChildren; i++) {
children[i].computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
}
}
}
}
/**
*
     * Compute edge forces using Barnes-Hut
* @param rowP a vector
* @param colP
* @param valP
* @param N the number of elements
* @param posF the positive force
*/
public void computeEdgeForces(INDArray rowP, INDArray colP, INDArray valP, int N, INDArray posF) {
if (!rowP.isVector())
throw new IllegalArgumentException("RowP must be a vector");
// Loop over all edges in the graph
// just execute native op
Nd4j.exec(new BarnesEdgeForces(rowP, colP, valP, data, N, posF));
/*
INDArray buf = Nd4j.create(data.dataType(), this.D);
double D;
for (int n = 0; n < N; n++) {
INDArray slice = data.slice(n);
for (int i = rowP.getInt(n); i < rowP.getInt(n + 1); i++) {
// Compute pairwise distance and Q-value
slice.subi(data.slice(colP.getInt(i)), buf);
D = 1.0 + Nd4j.getBlasWrapper().dot(buf, buf);
D = valP.getDouble(i) / D;
// Sum positive force
posF.slice(n).addi(buf.muli(D));
}
}
*/
}
public boolean isLeaf() {
return isLeaf;
}
/**
* Verifies the structure of the tree (does bounds checking on each node)
* @return true if the structure of the tree
* is correct.
*/
public boolean isCorrect() {
/*MemoryWorkspace workspace =
workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace()
: Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
workspaceConfigurationExternal,
workspaceExternal);
try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ {
for (int n = 0; n < size; n++) {
INDArray point = data.slice(index[n]);
if (!boundary.contains(point))
return false;
}
if (!isLeaf()) {
boolean correct = true;
for (int i = 0; i < numChildren; i++)
correct = correct && children[i].isCorrect();
return correct;
}
return true;
}
}
/**
* The depth of the node
* @return the depth of the node
*/
public int depth() {
if (isLeaf())
return 1;
int depth = 1;
int maxChildDepth = 0;
for (int i = 0; i < numChildren; i++) {
            maxChildDepth = Math.max(maxChildDepth, children[i].depth());
}
return depth + maxChildDepth;
}
private void fill(int n) {
if (indices.isEmpty() && parent == null)
for (int i = 0; i < n; i++) {
log.trace("Inserted " + i);
insert(i);
}
else
log.warn("Called fill already");
}
public SpTree[] getChildren() {
return children;
}
public int getD() {
return D;
}
public INDArray getCenterOfMass() {
return centerOfMass;
}
public Cell getBoundary() {
return boundary;
}
public int[] getIndex() {
return index;
}
public int getCumSize() {
return cumSize;
}
public void setCumSize(int cumSize) {
this.cumSize = cumSize;
}
public int getNumChildren() {
return numChildren;
}
public void setNumChildren(int numChildren) {
this.numChildren = numChildren;
}
}
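
A minimal Barnes-Hut sketch against the removed SpTree API; theta and shapes are illustrative, and AtomicDouble is the shaded-guava class imported above (hypothetical, not part of the deleted sources; imports elided):

    INDArray embedded = Nd4j.rand(500, 2);                              // e.g. 2-D t-SNE coordinates
    SpTree sp = new SpTree(embedded);                                   // builds the tree over all rows
    INDArray negForce = Nd4j.create(embedded.dataType(), embedded.columns());  // zero-initialised accumulator
    AtomicDouble sumQ = new AtomicDouble(0.0);
    sp.computeNonEdgeForces(0, 0.5, negForce, sumQ);                    // repulsive force for point 0, theta = 0.5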


@ -1,117 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.strategy;
import lombok.*;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition;
import org.deeplearning4j.clustering.condition.ConvergenceCondition;
import org.deeplearning4j.clustering.condition.FixedIterationCountCondition;
import java.io.Serializable;
@AllArgsConstructor(access = AccessLevel.PROTECTED)
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public abstract class BaseClusteringStrategy implements ClusteringStrategy, Serializable {
@Getter(AccessLevel.PUBLIC)
@Setter(AccessLevel.PROTECTED)
protected ClusteringStrategyType type;
@Getter(AccessLevel.PUBLIC)
@Setter(AccessLevel.PROTECTED)
protected Integer initialClusterCount;
@Getter(AccessLevel.PUBLIC)
@Setter(AccessLevel.PROTECTED)
protected ClusteringAlgorithmCondition optimizationPhaseCondition;
@Getter(AccessLevel.PUBLIC)
@Setter(AccessLevel.PROTECTED)
protected ClusteringAlgorithmCondition terminationCondition;
@Getter(AccessLevel.PUBLIC)
@Setter(AccessLevel.PROTECTED)
protected boolean inverse;
@Getter(AccessLevel.PUBLIC)
@Setter(AccessLevel.PROTECTED)
protected Distance distanceFunction;
@Getter(AccessLevel.PUBLIC)
@Setter(AccessLevel.PROTECTED)
protected boolean allowEmptyClusters;
public BaseClusteringStrategy(ClusteringStrategyType type, Integer initialClusterCount, Distance distanceFunction,
boolean allowEmptyClusters, boolean inverse) {
this.type = type;
this.initialClusterCount = initialClusterCount;
this.distanceFunction = distanceFunction;
this.allowEmptyClusters = allowEmptyClusters;
this.inverse = inverse;
}
public BaseClusteringStrategy(ClusteringStrategyType clusteringStrategyType, int initialClusterCount,
Distance distanceFunction, boolean inverse) {
this(clusteringStrategyType, initialClusterCount, distanceFunction, false, inverse);
}
/**
*
* @param maxIterationCount
* @return
*/
public BaseClusteringStrategy endWhenIterationCountEquals(int maxIterationCount) {
setTerminationCondition(FixedIterationCountCondition.iterationCountGreaterThan(maxIterationCount));
return this;
}
/**
*
* @param rate
* @return
*/
public BaseClusteringStrategy endWhenDistributionVariationRateLessThan(double rate) {
setTerminationCondition(ConvergenceCondition.distributionVariationRateLessThan(rate));
return this;
}
/**
* @return
*/
@Override
public boolean inverseDistanceCalculation() {
return inverse;
}
/**
*
* @param type
* @return
*/
public boolean isStrategyOfType(ClusteringStrategyType type) {
return type.equals(this.type);
}
/**
*
* @return
*/
public Integer getInitialClusterCount() {
return initialClusterCount;
}
}


@ -1,102 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.strategy;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition;
import org.deeplearning4j.clustering.iteration.IterationHistory;
/**
*
*/
public interface ClusteringStrategy {
/**
*
* @return
*/
boolean inverseDistanceCalculation();
/**
*
* @return
*/
ClusteringStrategyType getType();
/**
*
* @param type
* @return
*/
boolean isStrategyOfType(ClusteringStrategyType type);
/**
*
* @return
*/
Integer getInitialClusterCount();
/**
*
* @return
*/
Distance getDistanceFunction();
/**
*
* @return
*/
boolean isAllowEmptyClusters();
/**
*
* @return
*/
ClusteringAlgorithmCondition getTerminationCondition();
/**
*
* @return
*/
boolean isOptimizationDefined();
/**
*
* @param iterationHistory
* @return
*/
boolean isOptimizationApplicableNow(IterationHistory iterationHistory);
/**
*
* @param maxIterationCount
* @return
*/
BaseClusteringStrategy endWhenIterationCountEquals(int maxIterationCount);
/**
*
* @param rate
* @return
*/
BaseClusteringStrategy endWhenDistributionVariationRateLessThan(double rate);
}


@ -1,25 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.strategy;
public enum ClusteringStrategyType {
FIXED_CLUSTER_COUNT, OPTIMIZATION
}


@ -1,68 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.strategy;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.iteration.IterationHistory;
/**
*
*/
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public class FixedClusterCountStrategy extends BaseClusteringStrategy {
protected FixedClusterCountStrategy(Integer initialClusterCount, Distance distanceFunction,
boolean allowEmptyClusters, boolean inverse) {
super(ClusteringStrategyType.FIXED_CLUSTER_COUNT, initialClusterCount, distanceFunction, allowEmptyClusters,
inverse);
}
/**
*
* @param clusterCount
* @param distanceFunction
* @param inverse
* @return
*/
public static FixedClusterCountStrategy setup(int clusterCount, Distance distanceFunction, boolean inverse) {
return new FixedClusterCountStrategy(clusterCount, distanceFunction, false, inverse);
}
/**
* @return
*/
@Override
public boolean inverseDistanceCalculation() {
return inverse;
}
public boolean isOptimizationDefined() {
return false;
}
public boolean isOptimizationApplicableNow(IterationHistory iterationHistory) {
return false;
}
}
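
A sketch of configuring the removed fixed-cluster-count strategy; the cluster count and iteration cap are illustrative (hypothetical, not part of the deleted sources; imports elided):

    ClusteringStrategy strategy = FixedClusterCountStrategy
            .setup(5, Distance.EUCLIDEAN, false)        // 5 clusters, non-inverted euclidean distance
            .endWhenIterationCountEquals(100);          // terminate after 100 iterations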


@ -1,82 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.strategy;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition;
import org.deeplearning4j.clustering.condition.ConvergenceCondition;
import org.deeplearning4j.clustering.condition.FixedIterationCountCondition;
import org.deeplearning4j.clustering.iteration.IterationHistory;
import org.deeplearning4j.clustering.optimisation.ClusteringOptimization;
import org.deeplearning4j.clustering.optimisation.ClusteringOptimizationType;
public class OptimisationStrategy extends BaseClusteringStrategy {
public static int defaultIterationCount = 100;
private ClusteringOptimization clusteringOptimisation;
private ClusteringAlgorithmCondition clusteringOptimisationApplicationCondition;
protected OptimisationStrategy() {
super();
}
protected OptimisationStrategy(int initialClusterCount, Distance distanceFunction) {
super(ClusteringStrategyType.OPTIMIZATION, initialClusterCount, distanceFunction, false);
}
public static OptimisationStrategy setup(int initialClusterCount, Distance distanceFunction) {
return new OptimisationStrategy(initialClusterCount, distanceFunction);
}
public OptimisationStrategy optimize(ClusteringOptimizationType type, double value) {
clusteringOptimisation = new ClusteringOptimization(type, value);
return this;
}
public OptimisationStrategy optimizeWhenIterationCountMultipleOf(int value) {
clusteringOptimisationApplicationCondition = FixedIterationCountCondition.iterationCountGreaterThan(value);
return this;
}
public OptimisationStrategy optimizeWhenPointDistributionVariationRateLessThan(double rate) {
clusteringOptimisationApplicationCondition = ConvergenceCondition.distributionVariationRateLessThan(rate);
return this;
}
public double getClusteringOptimizationValue() {
return clusteringOptimisation.getValue();
}
public boolean isClusteringOptimizationType(ClusteringOptimizationType type) {
return clusteringOptimisation != null && clusteringOptimisation.getType().equals(type);
}
public boolean isOptimizationDefined() {
return clusteringOptimisation != null;
}
public boolean isOptimizationApplicableNow(IterationHistory iterationHistory) {
return clusteringOptimisationApplicationCondition != null
&& clusteringOptimisationApplicationCondition.isSatisfied(iterationHistory);
}
}
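
Similarly, a sketch for the optimisation strategy using only the conditions defined above; values are illustrative (hypothetical, not part of the deleted sources; imports elided):

    OptimisationStrategy opt = OptimisationStrategy
            .setup(5, Distance.EUCLIDEAN)               // start from 5 clusters
            .optimizeWhenIterationCountMultipleOf(10);  // re-optimise periodically
    opt.endWhenDistributionVariationRateLessThan(0.01); // terminate once the distribution stabilises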


@ -1,74 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.util;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.List;
import java.util.concurrent.*;
public class MultiThreadUtils {
private static Logger log = LoggerFactory.getLogger(MultiThreadUtils.class);
private static ExecutorService instance;
private MultiThreadUtils() {}
public static synchronized ExecutorService newExecutorService() {
int nThreads = Runtime.getRuntime().availableProcessors();
return new ThreadPoolExecutor(nThreads, nThreads, 60L, TimeUnit.SECONDS, new LinkedTransferQueue<Runnable>(),
new ThreadFactory() {
@Override
public Thread newThread(Runnable r) {
Thread t = Executors.defaultThreadFactory().newThread(r);
t.setDaemon(true);
return t;
}
});
}
public static void parallelTasks(final List<Runnable> tasks, ExecutorService executorService) {
int tasksCount = tasks.size();
final CountDownLatch latch = new CountDownLatch(tasksCount);
for (int i = 0; i < tasksCount; i++) {
final int taskIdx = i;
executorService.execute(new Runnable() {
public void run() {
try {
tasks.get(taskIdx).run();
} catch (Throwable e) {
log.info("Unchecked exception thrown by task", e);
} finally {
latch.countDown();
}
}
});
}
try {
latch.await();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
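
A sketch of running tasks with the helper above; the task bodies are illustrative (hypothetical, not part of the deleted sources; imports elided):

    ExecutorService pool = MultiThreadUtils.newExecutorService();
    List<Runnable> tasks = new ArrayList<>();
    for (int i = 0; i < 4; i++) {
        final int id = i;
        tasks.add(() -> System.out.println("task " + id + " done"));
    }
    MultiThreadUtils.parallelTasks(tasks, pool);   // blocks until all tasks have finished
    pool.shutdown();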


@ -1,61 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.util;
import java.util.Collection;
import java.util.HashSet;
import java.util.Set;
public class SetUtils {
private SetUtils() {}
// Set specific operations
public static <T> Set<T> intersection(Collection<T> parentCollection, Collection<T> removeFromCollection) {
Set<T> results = new HashSet<>(parentCollection);
results.retainAll(removeFromCollection);
return results;
}
public static <T> boolean intersectionP(Set<? extends T> s1, Set<? extends T> s2) {
for (T elt : s1) {
if (s2.contains(elt))
return true;
}
return false;
}
public static <T> Set<T> union(Set<? extends T> s1, Set<? extends T> s2) {
Set<T> s3 = new HashSet<>(s1);
s3.addAll(s2);
return s3;
}
    /** Returns s1 \ s2 */
public static <T> Set<T> difference(Collection<? extends T> s1, Collection<? extends T> s2) {
Set<T> s3 = new HashSet<>(s1);
s3.removeAll(s2);
return s3;
}
}
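
A short sketch of the set helpers with illustrative values (hypothetical, not part of the deleted sources; imports elided):

    Set<Integer> s1 = new HashSet<>(Arrays.asList(1, 2, 3));
    Set<Integer> s2 = new HashSet<>(Arrays.asList(2, 3, 4));
    SetUtils.union(s1, s2);          // {1, 2, 3, 4}
    SetUtils.intersection(s1, s2);   // {2, 3}
    SetUtils.difference(s1, s2);     // {1}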


@ -1,633 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.vptree;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.sptree.HeapObject;
import org.deeplearning4j.clustering.util.MathUtils;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce3.*;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import java.io.Serializable;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
@Slf4j
@Builder
@AllArgsConstructor
public class VPTree implements Serializable {
private static final long serialVersionUID = 1L;
public static final String EUCLIDEAN = "euclidean";
private double tau;
@Getter
@Setter
private INDArray items;
private List<INDArray> itemsList;
private Node root;
private String similarityFunction;
@Getter
private boolean invert = false;
private transient ExecutorService executorService;
@Getter
private int workers = 1;
private AtomicInteger size = new AtomicInteger(0);
private transient ThreadLocal<INDArray> scalars = new ThreadLocal<>();
private WorkspaceConfiguration workspaceConfiguration;
protected VPTree() {
        // no-arg constructor for serialization only
scalars = new ThreadLocal<>();
}
/**
*
* @param points
* @param invert
*/
public VPTree(INDArray points, boolean invert) {
this(points, "euclidean", 1, invert);
}
/**
*
* @param points
* @param invert
* @param workers number of parallel workers for tree building (increases memory requirements!)
*/
public VPTree(INDArray points, boolean invert, int workers) {
this(points, "euclidean", workers, invert);
}
/**
*
* @param items the items to use
* @param similarityFunction the similarity function to use
* @param invert whether to invert the distance (similarity functions have different min/max objectives)
*/
public VPTree(INDArray items, String similarityFunction, boolean invert) {
this.similarityFunction = similarityFunction;
this.invert = invert;
this.items = items;
root = buildFromPoints(items);
workers = 1;
}
/**
*
* @param items the items to use
* @param similarityFunction the similarity function to use
* @param workers number of parallel workers for tree building (increases memory requirements!)
* @param invert whether to invert the metric (different optimization objective)
*/
public VPTree(List<DataPoint> items, String similarityFunction, int workers, boolean invert) {
this.workers = workers;
val list = new INDArray[items.size()];
// build list of INDArrays first
for (int i = 0; i < items.size(); i++)
list[i] = items.get(i).getPoint();
//this.items.putRow(i, items.get(i).getPoint());
// just stack them out with concat :)
this.items = Nd4j.pile(list);
this.invert = invert;
this.similarityFunction = similarityFunction;
root = buildFromPoints(this.items);
}
/**
*
* @param items
* @param similarityFunction
*/
public VPTree(INDArray items, String similarityFunction) {
this(items, similarityFunction, 1, false);
}
/**
*
* @param items
* @param similarityFunction
* @param workers number of parallel workers for tree building (increases memory requirements!)
* @param invert
*/
public VPTree(INDArray items, String similarityFunction, int workers, boolean invert) {
this.similarityFunction = similarityFunction;
this.invert = invert;
this.items = items;
this.workers = workers;
root = buildFromPoints(items);
}
/**
*
* @param items
* @param similarityFunction
*/
public VPTree(List<DataPoint> items, String similarityFunction) {
this(items, similarityFunction, 1, false);
}
/**
*
* @param items
*/
public VPTree(INDArray items) {
this(items, EUCLIDEAN);
}
/**
*
* @param items
*/
public VPTree(List<DataPoint> items) {
this(items, EUCLIDEAN);
}
/**
* Create an ndarray
* from the datapoints
* @param data
* @return
*/
public static INDArray buildFromData(List<DataPoint> data) {
INDArray ret = Nd4j.create(data.size(), data.get(0).getD());
for (int i = 0; i < ret.slices(); i++)
ret.putSlice(i, data.get(i).getPoint());
return ret;
}
/**
* Compute the distance from every row of the given items to the base point,
* writing the results into distancesArr.
*
* @param items the matrix of points, one row per point
* @param basePoint the point to measure distances from
* @param distancesArr the array the distances are written into
*/
public void calcDistancesRelativeTo(INDArray items, INDArray basePoint, INDArray distancesArr) {
switch (similarityFunction) {
case "euclidean":
Nd4j.getExecutioner().exec(new EuclideanDistance(items, basePoint, distancesArr, true,-1));
break;
case "cosinedistance":
Nd4j.getExecutioner().exec(new CosineDistance(items, basePoint, distancesArr, true, -1));
break;
case "cosinesimilarity":
Nd4j.getExecutioner().exec(new CosineSimilarity(items, basePoint, distancesArr, true, -1));
break;
case "manhattan":
Nd4j.getExecutioner().exec(new ManhattanDistance(items, basePoint, distancesArr, true, -1));
break;
case "dot":
Nd4j.getExecutioner().exec(new Dot(items, basePoint, distancesArr, -1));
break;
case "jaccard":
Nd4j.getExecutioner().exec(new JaccardDistance(items, basePoint, distancesArr, true, -1));
break;
case "hamming":
Nd4j.getExecutioner().exec(new HammingDistance(items, basePoint, distancesArr, true, -1));
break;
default:
Nd4j.getExecutioner().exec(new EuclideanDistance(items, basePoint, distancesArr, true, -1));
break;
}
if (invert)
distancesArr.negi();
}
public void calcDistancesRelativeTo(INDArray basePoint, INDArray distancesArr) {
calcDistancesRelativeTo(items, basePoint, distancesArr);
}
/**
* Compute the distance between two points using the configured similarity function
* (the result is negated when invert is set).
* @param arr1 the first point
* @param arr2 the second point
* @return the distance between the two points
*/
public double distance(INDArray arr1, INDArray arr2) {
if (scalars == null)
scalars = new ThreadLocal<>();
if (scalars.get() == null)
scalars.set(Nd4j.scalar(arr1.dataType(), 0.0));
switch (similarityFunction) {
case "jaccard":
double ret7 = Nd4j.getExecutioner()
.execAndReturn(new JaccardDistance(arr1, arr2, scalars.get()))
.getFinalResult().doubleValue();
return invert ? -ret7 : ret7;
case "hamming":
double ret8 = Nd4j.getExecutioner()
.execAndReturn(new HammingDistance(arr1, arr2, scalars.get()))
.getFinalResult().doubleValue();
return invert ? -ret8 : ret8;
case "euclidean":
double ret = Nd4j.getExecutioner()
.execAndReturn(new EuclideanDistance(arr1, arr2, scalars.get()))
.getFinalResult().doubleValue();
return invert ? -ret : ret;
case "cosinesimilarity":
double ret2 = Nd4j.getExecutioner()
.execAndReturn(new CosineSimilarity(arr1, arr2, scalars.get()))
.getFinalResult().doubleValue();
return invert ? -ret2 : ret2;
case "cosinedistance":
double ret6 = Nd4j.getExecutioner()
.execAndReturn(new CosineDistance(arr1, arr2, scalars.get()))
.getFinalResult().doubleValue();
return invert ? -ret6 : ret6;
case "manhattan":
double ret3 = Nd4j.getExecutioner()
.execAndReturn(new ManhattanDistance(arr1, arr2, scalars.get()))
.getFinalResult().doubleValue();
return invert ? -ret3 : ret3;
case "dot":
double dotRet = Nd4j.getBlasWrapper().dot(arr1, arr2);
return invert ? -dotRet : dotRet;
default:
double ret4 = Nd4j.getExecutioner()
.execAndReturn(new EuclideanDistance(arr1, arr2, scalars.get()))
.getFinalResult().doubleValue();
return invert ? -ret4 : ret4;
}
}
protected class NodeBuilder implements Callable<Node> {
protected List<INDArray> list;
protected List<Integer> indices;
public NodeBuilder(List<INDArray> list, List<Integer> indices) {
this.list = list;
this.indices = indices;
}
@Override
public Node call() throws Exception {
return buildFromPoints(list, indices);
}
}
private Node buildFromPoints(List<INDArray> points, List<Integer> indices) {
Node ret = new Node(0, 0);
// nothing to sort here
if (points.size() == 1) {
ret.point = points.get(0);
ret.index = indices.get(0);
return ret;
}
// stack the remaining points into a single matrix for batched distance computation
INDArray items = Nd4j.vstack(points);
int randomPoint = MathUtils.randomNumberBetween(0, items.rows() - 1, Nd4j.getRandom());
INDArray basePoint = points.get(randomPoint);
ret.point = basePoint;
ret.index = indices.get(randomPoint);
INDArray distancesArr = Nd4j.create(items.rows(), 1);
calcDistancesRelativeTo(items, basePoint, distancesArr);
double medianDistance = distancesArr.medianNumber().doubleValue();
ret.threshold = (float) medianDistance;
List<INDArray> leftPoints = new ArrayList<>();
List<Integer> leftIndices = new ArrayList<>();
List<INDArray> rightPoints = new ArrayList<>();
List<Integer> rightIndices = new ArrayList<>();
for (int i = 0; i < distancesArr.length(); i++) {
if (i == randomPoint)
continue;
if (distancesArr.getDouble(i) < medianDistance) {
leftPoints.add(points.get(i));
leftIndices.add(indices.get(i));
} else {
rightPoints.add(points.get(i));
rightIndices.add(indices.get(i));
}
}
// recurse on the two partitions, in parallel when multiple workers are configured
if (workers > 1) {
if (!leftPoints.isEmpty())
ret.futureLeft = executorService.submit(new NodeBuilder(leftPoints, leftIndices));
if (!rightPoints.isEmpty())
ret.futureRight = executorService.submit(new NodeBuilder(rightPoints, rightIndices));
} else {
if (!leftPoints.isEmpty())
ret.left = buildFromPoints(leftPoints, leftIndices);
if (!rightPoints.isEmpty())
ret.right = buildFromPoints(rightPoints, rightIndices);
}
return ret;
}
private Node buildFromPoints(INDArray items) {
if (executorService == null && items == this.items && workers > 1) {
final val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();
executorService = Executors.newFixedThreadPool(workers, new ThreadFactory() {
@Override
public Thread newThread(final Runnable r) {
Thread t = new Thread(new Runnable() {
@Override
public void run() {
Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
r.run();
}
});
t.setDaemon(true);
t.setName("VPTree thread");
return t;
}
});
}
final Node ret = new Node(0, 0);
size.incrementAndGet();
// pick a random vantage point and compute the distance from it to every row
int randomPoint = MathUtils.randomNumberBetween(0, items.rows() - 1, Nd4j.getRandom());
INDArray basePoint = items.getRow(randomPoint, true);
INDArray distancesArr = Nd4j.create(items.rows(), 1);
ret.point = basePoint;
ret.index = randomPoint;
calcDistancesRelativeTo(items, basePoint, distancesArr);
double medianDistance = distancesArr.medianNumber().doubleValue();
ret.threshold = (float) medianDistance;
List<INDArray> leftPoints = new ArrayList<>();
List<Integer> leftIndices = new ArrayList<>();
List<INDArray> rightPoints = new ArrayList<>();
List<Integer> rightIndices = new ArrayList<>();
for (int i = 0; i < distancesArr.length(); i++) {
if (i == randomPoint)
continue;
if (distancesArr.getDouble(i) < medianDistance) {
leftPoints.add(items.getRow(i, true));
leftIndices.add(i);
} else {
rightPoints.add(items.getRow(i, true));
rightIndices.add(i);
}
}
// recurse on the two partitions
if (!leftPoints.isEmpty())
ret.left = buildFromPoints(leftPoints, leftIndices);
if (!rightPoints.isEmpty())
ret.right = buildFromPoints(rightPoints, rightIndices);
// resolve any futures produced by parallel workers
if (ret.left != null)
ret.left.fetchFutures();
if (ret.right != null)
ret.right.fetchFutures();
if (executorService != null)
executorService.shutdown();
return ret;
}
public void search(@NonNull INDArray target, int k, List<DataPoint> results, List<Double> distances) {
search(target, k, results, distances, true);
}
public void search(@NonNull INDArray target, int k, List<DataPoint> results, List<Double> distances,
boolean filterEqual) {
search(target, k, results, distances, filterEqual, false);
}
/**
* Search for the k nearest neighbours of the target point.
*
* @param target the point to search around
* @param k the number of neighbours to return
* @param results the list the resulting data points are written into
* @param distances the list the corresponding distances are written into
* @param filterEqual if true, remove an exact (zero-distance) match from the results
* @param dropEdge if true, apply the filtering/trimming step even when no more than k results were found
*/
public void search(@NonNull INDArray target, int k, List<DataPoint> results, List<Double> distances,
boolean filterEqual, boolean dropEdge) {
if (items != null)
if (!target.isVectorOrScalar() || target.columns() != items.columns() || target.rows() > 1)
throw new ND4JIllegalStateException("Target for search should have shape of [1, "
+ items.columns() + "] but got " + Arrays.toString(target.shape()) + " instead");
k = Math.min(k, items.rows());
results.clear();
distances.clear();
PriorityQueue<HeapObject> pq = new PriorityQueue<>(items.rows(), new HeapObjectComparator());
search(root, target, k + (filterEqual ? 2 : 1), pq, Double.MAX_VALUE);
while (!pq.isEmpty()) {
HeapObject ho = pq.peek();
results.add(new DataPoint(ho.getIndex(), ho.getPoint()));
distances.add(ho.getDistance());
pq.poll();
}
Collections.reverse(results);
Collections.reverse(distances);
if (dropEdge || results.size() > k) {
if (filterEqual && distances.get(0) == 0.0) {
results.remove(0);
distances.remove(0);
}
while (results.size() > k) {
results.remove(results.size() - 1);
distances.remove(distances.size() - 1);
}
}
}
/**
* Recursively search the subtree rooted at the given node.
*
* @param node the node to search from
* @param target the point to search around
* @param k the number of neighbours to collect
* @param pq the bounded priority queue holding the current best candidates
* @param cTau the current search radius
*/
public void search(Node node, INDArray target, int k, PriorityQueue<HeapObject> pq, double cTau) {
if (node == null)
return;
double tau = cTau;
INDArray get = node.getPoint();
double distance = distance(get, target);
if (distance < tau) {
if (pq.size() == k)
pq.poll();
pq.add(new HeapObject(node.getIndex(), node.getPoint(), distance));
if (pq.size() == k)
tau = pq.peek().getDistance();
}
Node left = node.getLeft();
Node right = node.getRight();
if (left == null && right == null)
return;
if (distance < node.getThreshold()) {
if (distance - tau < node.getThreshold()) { // if there can still be neighbors inside the ball, recursively search left child first
search(left, target, k, pq, tau);
}
if (distance + tau >= node.getThreshold()) { // if there can still be neighbors outside the ball, recursively search right child
search(right, target, k, pq, tau);
}
} else {
if (distance + tau >= node.getThreshold()) { // if there can still be neighbors outside the ball, recursively search right child first
search(right, target, k, pq, tau);
}
if (distance - tau < node.getThreshold()) { // if there can still be neighbors inside the ball, recursively search left child
search(left, target, k, pq, tau);
}
}
}
protected class HeapObjectComparator implements Comparator<HeapObject> {
@Override
public int compare(HeapObject o1, HeapObject o2) {
return Double.compare(o2.getDistance(), o1.getDistance());
}
}
@Data
public static class Node implements Serializable {
private static final long serialVersionUID = 2L;
private int index;
private float threshold;
private Node left, right;
private INDArray point;
protected transient Future<Node> futureLeft;
protected transient Future<Node> futureRight;
public Node(int index, float threshold) {
this.index = index;
this.threshold = threshold;
}
public void fetchFutures() {
try {
if (futureLeft != null) {
left = futureLeft.get();
}
if (futureRight != null) {
right = futureRight.get();
}
if (left != null)
left.fetchFutures();
if (right != null)
right.fetchFutures();
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
}
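For reference, a minimal usage sketch of the class above, assuming the deeplearning4j clustering artifact is on the classpath; the class name, point counts and query values below are illustrative only, not part of the original module:

import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ArrayList;
import java.util.List;

public class VPTreeExample {
    public static void main(String[] args) {
        // 100 random 8-dimensional points, one per row
        INDArray points = Nd4j.rand(100, 8);

        // build the tree with euclidean distance, a single worker, no inversion
        VPTree tree = new VPTree(points, "euclidean", 1, false);

        // query the 5 nearest neighbours of a random target row vector
        INDArray target = Nd4j.rand(1, 8);
        List<DataPoint> results = new ArrayList<>();
        List<Double> distances = new ArrayList<>();
        tree.search(target, 5, results, distances);

        for (int i = 0; i < distances.size(); i++) {
            System.out.println("neighbour " + i + ": distance " + distances.get(i));
        }
    }
}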


@ -1,79 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.vptree;
import lombok.Getter;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
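/**
* Brute-force k nearest neighbour search over the items held by a VPTree: the distance from the
* target to every item is computed in one batch, the distances are sorted, and the first k entries
* are returned. Useful when the caller always needs exactly k results regardless of tree pruning.
*/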
public class VPTreeFillSearch {
private VPTree vpTree;
private int k;
@Getter
private List<DataPoint> results;
@Getter
private List<Double> distances;
private INDArray target;
public VPTreeFillSearch(VPTree vpTree, int k, INDArray target) {
this.vpTree = vpTree;
this.k = k;
this.target = target;
}
public void search() {
results = new ArrayList<>();
distances = new ArrayList<>();
// brute force: compute the distance from the target to every item, sort, and take the first k
INDArray distancesArr = Nd4j.create(vpTree.getItems().rows(), 1);
vpTree.calcDistancesRelativeTo(target, distancesArr);
INDArray[] sortWithIndices = Nd4j.sortWithIndices(distancesArr, 0, !vpTree.isInvert());
results.clear();
distances.clear();
if (vpTree.getItems().isVector()) {
for (int i = 0; i < k; i++) {
int idx = sortWithIndices[0].getInt(i);
results.add(new DataPoint(idx, Nd4j.scalar(vpTree.getItems().getDouble(idx))));
distances.add(sortWithIndices[1].getDouble(i));
}
} else {
for (int i = 0; i < k; i++) {
int idx = sortWithIndices[0].getInt(i);
results.add(new DataPoint(idx, vpTree.getItems().getRow(idx)));
distances.add(sortWithIndices[1].getDouble(i));
}
}
}
}
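A minimal sketch of how this fill search is meant to be driven, assuming the same classpath as above; the class name and sample data are illustrative:

import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.clustering.vptree.VPTreeFillSearch;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.List;

public class VPTreeFillSearchExample {
    public static void main(String[] args) {
        // 50 random 4-dimensional points, one per row
        INDArray points = Nd4j.rand(50, 4);
        VPTree tree = new VPTree(points, "euclidean");

        // always returns exactly 10 results, filled by brute-force distance computation
        INDArray target = Nd4j.rand(1, 4);
        VPTreeFillSearch fillSearch = new VPTreeFillSearch(tree, 10, target);
        fillSearch.search();

        List<DataPoint> results = fillSearch.getResults();
        List<Double> distances = fillSearch.getDistances();
        System.out.println("found " + results.size() + " neighbours, closest distance: " + distances.get(0));
    }
}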


@ -1,21 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.vptree;


@ -1,46 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.cluster;
import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.factory.Nd4j;
import java.util.ArrayList;
import java.util.List;
public class ClusterSetTest {
@Test
public void testGetMostPopulatedClusters() {
ClusterSet clusterSet = new ClusterSet(false);
List<Cluster> clusters = new ArrayList<>();
for (int i = 0; i < 5; i++) {
Cluster cluster = new Cluster();
cluster.setPoints(Point.toPoints(Nd4j.randn(i + 1, 5)));
clusters.add(cluster);
}
clusterSet.setClusters(clusters);
List<Cluster> mostPopulatedClusters = clusterSet.getMostPopulatedClusters(5);
for (int i = 0; i < 5; i++) {
Assert.assertEquals(5 - i, mostPopulatedClusters.get(i).getPoints().size());
}
}
}


@ -1,422 +0,0 @@
/*
* ******************************************************************************
* *
* *
* * This program and the accompanying materials are made available under the
* * terms of the Apache License, Version 2.0 which is available at
* * https://www.apache.org/licenses/LICENSE-2.0.
* *
* * See the NOTICE file distributed with this work for additional
* * information regarding copyright ownership.
* * Unless required by applicable law or agreed to in writing, software
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* * License for the specific language governing permissions and limitations
* * under the License.
* *
* * SPDX-License-Identifier: Apache-2.0
* *****************************************************************************
*/
package org.deeplearning4j.clustering.kdtree;
import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.joda.time.Duration;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import org.nd4j.shade.guava.base.Stopwatch;
import org.nd4j.shade.guava.primitives.Doubles;
import org.nd4j.shade.guava.primitives.Floats;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
public class KDTreeTest extends BaseDL4JTest {
@Override
public long getTimeoutMilliseconds() {
return 120000L;
}
private KDTree kdTree;
@BeforeClass
public static void beforeClass(){
Nd4j.setDataType(DataType.FLOAT);
}
@Before
public void setUp() {
kdTree = new KDTree(2);
float[] data = new float[]{7,2};
kdTree.insert(Nd4j.createFromArray(data));
data = new float[]{5,4};
kdTree.insert(Nd4j.createFromArray(data));
data = new float[]{2,3};
kdTree.insert(Nd4j.createFromArray(data));
data = new float[]{4,7};
kdTree.insert(Nd4j.createFromArray(data));
data = new float[]{9,6};
kdTree.insert(Nd4j.createFromArray(data));
data = new float[]{8,1};
kdTree.insert(Nd4j.createFromArray(data));
}
@Test
public void testTree() {
KDTree tree = new KDTree(2);
INDArray half = Nd4j.create(new double[] {0.5, 0.5}, new long[]{1,2}).castTo(DataType.FLOAT);
INDArray one = Nd4j.create(new double[] {1, 1}, new long[]{1,2}).castTo(DataType.FLOAT);
tree.insert(half);
tree.insert(one);
Pair<Double, INDArray> pair = tree.nn(Nd4j.create(new double[] {0.5, 0.5}, new long[]{1,2}).castTo(DataType.FLOAT));
assertEquals(half, pair.getValue());
}
@Test
public void testInsert() {
int elements = 10;
List<Double> digits = Arrays.asList(1.0, 0.0, 2.0, 3.0);
KDTree kdTree = new KDTree(digits.size());
List<List<Double>> lists = new ArrayList<>();
for (int i = 0; i < elements; i++) {
List<Double> thisList = new ArrayList<>(digits.size());
for (int k = 0; k < digits.size(); k++) {
thisList.add(digits.get(k) + i);
}
lists.add(thisList);
}
for (int i = 0; i < elements; i++) {
double[] features = Doubles.toArray(lists.get(i));
INDArray ind = Nd4j.create(features, new long[]{1, features.length}, DataType.FLOAT);
kdTree.insert(ind);
assertEquals(i + 1, kdTree.size());
}
}
@Test
public void testDelete() {
int elements = 10;
List<Double> digits = Arrays.asList(1.0, 0.0, 2.0, 3.0);
KDTree kdTree = new KDTree(digits.size());
List<List<Double>> lists = new ArrayList<>();
for (int i = 0; i < elements; i++) {
List<Double> thisList = new ArrayList<>(digits.size());
for (int k = 0; k < digits.size(); k++) {
thisList.add(digits.get(k) + i);
}
lists.add(thisList);
}
INDArray toDelete = Nd4j.empty(DataType.DOUBLE),
leafToDelete = Nd4j.empty(DataType.DOUBLE);
for (int i = 0; i < elements; i++) {
double[] features = Doubles.toArray(lists.get(i));
INDArray ind = Nd4j.create(features, new long[]{1, features.length}, DataType.FLOAT);
if (i == 1)
toDelete = ind;
if (i == elements - 1) {
leafToDelete = ind;
}
kdTree.insert(ind);
assertEquals(i + 1, kdTree.size());
}
kdTree.delete(toDelete);
assertEquals(9, kdTree.size());
kdTree.delete(leafToDelete);
assertEquals(8, kdTree.size());
}
@Test
public void testNN() {
int n = 10;
// make a KD-tree of dimension n
KDTree kdTree = new KDTree(n);
for (int i = -1; i < n; i++) {
// Insert a unit vector along each dimension
List<Double> vec = new ArrayList<>(n);
// i = -1 ensures the origin is in the Tree
for (int k = 0; k < n; k++) {
vec.add((k == i) ? 1.0 : 0.0);
}
INDArray indVec = Nd4j.create(Doubles.toArray(vec), new long[]{1, vec.size()}, DataType.FLOAT);
kdTree.insert(indVec);
}
Random rand = new Random();
// random point in the Hypercube
List<Double> pt = new ArrayList<>(n);
for (int k = 0; k < n; k++) {
pt.add(rand.nextDouble());
}
Pair<Double, INDArray> result = kdTree.nn(Nd4j.create(Doubles.toArray(pt), new long[]{1, pt.size()}, DataType.FLOAT));
// Always true for points in the unitary hypercube
assertTrue(result.getKey() < Double.MAX_VALUE);
}
@Test
public void testKNN() {
int dimensions = 512;
int vectorsNo = isIntegrationTests() ? 50000 : 1000;
// make a KD-tree with the given number of dimensions
Stopwatch stopwatch = Stopwatch.createStarted();
KDTree kdTree = new KDTree(dimensions);
for (int i = -1; i < vectorsNo; i++) {
// insert a random vector on each iteration (the loop starts at -1, so vectorsNo + 1 vectors are inserted)
INDArray indVec = Nd4j.rand(DataType.FLOAT, 1,dimensions);
kdTree.insert(indVec);
}
stopwatch.stop();
System.out.println("Time elapsed for " + kdTree.size() + " nodes construction is "+ stopwatch.elapsed(SECONDS));
Random rand = new Random();
// random point in the Hypercube
List<Double> pt = new ArrayList<>(dimensions);
for (int k = 0; k < dimensions; k++) {
pt.add(rand.nextFloat() * 10.0);
}
stopwatch.reset();
stopwatch.start();
List<Pair<Float, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Floats.toArray(pt))), 20.0f);
stopwatch.stop();
System.out.println("Time elapsed for Search is "+ stopwatch.elapsed(MILLISECONDS));
}
@Test
public void testKNN_Simple() {
int n = 2;
KDTree kdTree = new KDTree(n);
float[] data = new float[]{3,3};
kdTree.insert(Nd4j.createFromArray(data));
data = new float[]{1,1};
kdTree.insert(Nd4j.createFromArray(data));
data = new float[]{2,2};
kdTree.insert(Nd4j.createFromArray(data));
data = new float[]{0,0};
List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 4.5f);
assertEquals(1.0, result.get(0).getSecond().getDouble(0), 1e-5);
assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5);
assertEquals(2.0, result.get(1).getSecond().getDouble(0), 1e-5);
assertEquals(2.0, result.get(1).getSecond().getDouble(1), 1e-5);
assertEquals(3.0, result.get(2).getSecond().getDouble(0), 1e-5);
assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5);
}
@Test
public void testKNN_1() {
assertEquals(6, kdTree.size());
float[] data = new float[]{8,1};
List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5);
assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5);
assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5);
assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5);
assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5);
assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5);
assertEquals(9.0, result.get(3).getSecond().getFloat(0), 1e-5);
assertEquals(6.0, result.get(3).getSecond().getFloat(1), 1e-5);
assertEquals(2.0, result.get(4).getSecond().getFloat(0), 1e-5);
assertEquals(3.0, result.get(4).getSecond().getFloat(1), 1e-5);
assertEquals(4.0, result.get(5).getSecond().getFloat(0), 1e-5);
assertEquals(7.0, result.get(5).getSecond().getFloat(1), 1e-5);
}
@Test
public void testKNN_2() {
float[] data = new float[]{8, 1};
List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5);
assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5);
assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5);
assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5);
assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5);
assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5);
}
@Test
public void testKNN_3() {
float[] data = new float[]{2, 3};
List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5);
assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5);
assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5);
assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5);
assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5);
assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5);
}
@Test
public void testKNN_4() {
float[] data = new float[]{2, 3};
List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
}
@Test
public void testKNN_5() {
float[] data = new float[]{2, 3};
List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f);
assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5);
assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5);
assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5);
assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5);
assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5);
assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5);
}
@Test
public void test_KNN_6() {
float[] data = new float[]{4, 6};
List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5);
assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5);
assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5);
assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5);
assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5);
assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5);
assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5);
assertEquals(8.0, result.get(5).getSecond().getDouble(0), 1e-5);
assertEquals(1.0, result.get(5).getSecond().getDouble(1), 1e-5);
}
@Test
public void test_KNN_7() {
float[] data = new float[]{4, 6};
List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5);
assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5);
assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5);
assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5);
assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5);
assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5);
assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5);
}
@Test
public void test_KNN_8() {
float[] data = new float[]{4, 6};
List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f);
assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5);
assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5);
assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5);
assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5);
assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5);
assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5);
assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5);
assertEquals(8.0, result.get(5).getSecond().getDouble(0), 1e-5);
assertEquals(1.0, result.get(5).getSecond().getDouble(1), 1e-5);
}
@Test
public void testNoDuplicates() {
int N = 100;
KDTree bigTree = new KDTree(2);
List<INDArray> points = new ArrayList<>();
for (int i = 0; i < N; ++i) {
double[] data = new double[]{i, i};
points.add(Nd4j.createFromArray(data));
}
for (int i = 0; i < N; ++i) {
bigTree.insert(points.get(i));
}
assertEquals(N, bigTree.size());
INDArray node = Nd4j.empty(DataType.DOUBLE);
for (int i = 0; i < N; ++i) {
node = bigTree.delete(node.isEmpty() ? points.get(i) : node);
}
assertEquals(0, bigTree.size());
}
@Ignore
@Test
public void performanceTest() {
int n = 2;
int num = 100000;
// make a KD-tree of dimension n
long start = System.currentTimeMillis();
KDTree kdTree = new KDTree(n);
INDArray inputArray = Nd4j.randn(DataType.DOUBLE, num, n);
for (int i = 0; i < num; ++i) {
kdTree.insert(inputArray.getRow(i));
}
long end = System.currentTimeMillis();
Duration duration = new Duration(start, end);
System.out.println("Elapsed time for tree construction " + duration.getStandardSeconds() + " " + duration.getMillis());
List<Float> pt = new ArrayList<>(num);
for (int k = 0; k < n; k++) {
pt.add((float)(num / 2));
}
start = System.currentTimeMillis();
List<Pair<Float, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0f);
end = System.currentTimeMillis();
duration = new Duration(start, end);
long elapsed = end - start;
System.out.println("Elapsed time for tree search " + duration.getStandardSeconds() + " " + duration.getMillis());
for (val pair : list) {
System.out.println(pair.getFirst() + " " + pair.getSecond()) ;
}
}
}
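For completeness, a minimal sketch of the KDTree API exercised by the tests above, outside the JUnit harness; the class name and sample coordinates are illustrative and mirror the shapes used in testTree:

import org.deeplearning4j.clustering.kdtree.KDTree;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class KDTreeExample {
    public static void main(String[] args) {
        // 2-dimensional tree, points supplied as [1, 2] row vectors
        KDTree tree = new KDTree(2);
        tree.insert(Nd4j.create(new double[]{7, 2}, new long[]{1, 2}).castTo(DataType.FLOAT));
        tree.insert(Nd4j.create(new double[]{5, 4}, new long[]{1, 2}).castTo(DataType.FLOAT));
        tree.insert(Nd4j.create(new double[]{9, 6}, new long[]{1, 2}).castTo(DataType.FLOAT));

        // nearest neighbour of (6, 3)
        INDArray query = Nd4j.create(new double[]{6, 3}, new long[]{1, 2}).castTo(DataType.FLOAT);
        Pair<Double, INDArray> nearest = tree.nn(query);
        System.out.println("nearest point " + nearest.getValue() + " at distance " + nearest.getKey());
    }
}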
