Remove more unused modules
parent fa8537f0c7
commit ee06fdd16f

@@ -1,64 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  ~ /* ******************************************************************************
  ~  *
  ~  *
  ~  * This program and the accompanying materials are made available under the
  ~  * terms of the Apache License, Version 2.0 which is available at
  ~  * https://www.apache.org/licenses/LICENSE-2.0.
  ~  *
  ~  * See the NOTICE file distributed with this work for additional
  ~  * information regarding copyright ownership.
  ~  * Unless required by applicable law or agreed to in writing, software
  ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~  * License for the specific language governing permissions and limitations
  ~  * under the License.
  ~  *
  ~  * SPDX-License-Identifier: Apache-2.0
  ~  ******************************************************************************/
  -->

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-spark-inference-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>datavec-spark-inference-client</artifactId>

    <name>datavec-spark-inference-client</name>

    <dependencies>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-spark-inference-server_2.11</artifactId>
            <version>1.0.0-SNAPSHOT</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-spark-inference-model</artifactId>
            <version>${project.parent.version}</version>
        </dependency>
        <dependency>
            <groupId>com.mashape.unirest</groupId>
            <artifactId>unirest-java</artifactId>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
        </profile>
    </profiles>
</project>

@@ -1,292 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  * See the NOTICE file distributed with this work for additional
 *  * information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.spark.inference.client;

import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import com.mashape.unirest.http.exceptions.UnirestException;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.transform.TransformProcess;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.model.*;
import org.datavec.spark.inference.model.service.DataVecTransformService;
import org.nd4j.shade.jackson.core.JsonProcessingException;

import java.io.IOException;

@AllArgsConstructor
@Slf4j
public class DataVecTransformClient implements DataVecTransformService {
    private String url;

    static {
        // Register the shared Unirest object mapper only once
        Unirest.setObjectMapper(new ObjectMapper() {
            private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
                    new org.nd4j.shade.jackson.databind.ObjectMapper();

            public <T> T readValue(String value, Class<T> valueType) {
                try {
                    return jacksonObjectMapper.readValue(value, valueType);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }

            public String writeValue(Object value) {
                try {
                    return jacksonObjectMapper.writeValueAsString(value);
                } catch (JsonProcessingException e) {
                    throw new RuntimeException(e);
                }
            }
        });
    }

    /**
     * @param transformProcess
     */
    @Override
    public void setCSVTransformProcess(TransformProcess transformProcess) {
        try {
            String s = transformProcess.toJson();
            Unirest.post(url + "/transformprocess").header("accept", "application/json")
                    .header("Content-Type", "application/json").body(s).asJson();
        } catch (UnirestException e) {
            log.error("Error in setCSVTransformProcess()", e);
        }
    }

    @Override
    public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    /**
     * @return
     */
    @Override
    public TransformProcess getCSVTransformProcess() {
        try {
            String s = Unirest.get(url + "/transformprocess").header("accept", "application/json")
                    .header("Content-Type", "application/json").asString().getBody();
            return TransformProcess.fromJson(s);
        } catch (UnirestException e) {
            log.error("Error in getCSVTransformProcess()", e);
        }

        return null;
    }

    @Override
    public ImageTransformProcess getImageTransformProcess() {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    /**
     * @param transform
     * @return
     */
    @Override
    public SingleCSVRecord transformIncremental(SingleCSVRecord transform) {
        try {
            SingleCSVRecord singleCsvRecord = Unirest.post(url + "/transformincremental")
                    .header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .body(transform).asObject(SingleCSVRecord.class).getBody();
            return singleCsvRecord;
        } catch (UnirestException e) {
            log.error("Error in transformIncremental(SingleCSVRecord)", e);
        }
        return null;
    }

    /**
     * @param batchCSVRecord
     * @return
     */
    @Override
    public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
        try {
            SequenceBatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform").header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "TRUE")
                    .body(batchCSVRecord)
                    .asObject(SequenceBatchCSVRecord.class)
                    .getBody();
            return batchCSVRecord1;
        } catch (UnirestException e) {
            log.error("Error in transform(SequenceBatchCSVRecord)", e);
        }

        return null;
    }

    /**
     * @param batchCSVRecord
     * @return
     */
    @Override
    public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) {
        try {
            BatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform").header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "FALSE")
                    .body(batchCSVRecord)
                    .asObject(BatchCSVRecord.class)
                    .getBody();
            return batchCSVRecord1;
        } catch (UnirestException e) {
            log.error("Error in transform(BatchCSVRecord)", e);
        }

        return null;
    }

    /**
     * @param batchCSVRecord
     * @return
     */
    @Override
    public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) {
        try {
            Base64NDArrayBody batchArray1 = Unirest.post(url + "/transformarray").header("accept", "application/json")
                    .header("Content-Type", "application/json").body(batchCSVRecord)
                    .asObject(Base64NDArrayBody.class).getBody();
            return batchArray1;
        } catch (UnirestException e) {
            log.error("Error in transformArray(BatchCSVRecord)", e);
        }

        return null;
    }

    /**
     * @param singleCsvRecord
     * @return
     */
    @Override
    public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) {
        try {
            Base64NDArrayBody array = Unirest.post(url + "/transformincrementalarray")
                    .header("accept", "application/json").header("Content-Type", "application/json")
                    .body(singleCsvRecord).asObject(Base64NDArrayBody.class).getBody();
            return array;
        } catch (UnirestException e) {
            log.error("Error in transformArrayIncremental(SingleCSVRecord)", e);
        }

        return null;
    }

    @Override
    public Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    @Override
    public Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    /**
     * @param singleCsvRecord
     * @return
     */
    @Override
    public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) {
        try {
            Base64NDArrayBody array = Unirest.post(url + "/transformincrementalarray")
                    .header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "true")
                    .body(singleCsvRecord).asObject(Base64NDArrayBody.class).getBody();
            return array;
        } catch (UnirestException e) {
            log.error("Error in transformSequenceArrayIncremental", e);
        }

        return null;
    }

    /**
     * @param batchCSVRecord
     * @return
     */
    @Override
    public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) {
        try {
            Base64NDArrayBody batchArray1 = Unirest.post(url + "/transformarray").header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "true")
                    .body(batchCSVRecord)
                    .asObject(Base64NDArrayBody.class).getBody();
            return batchArray1;
        } catch (UnirestException e) {
            log.error("Error in transformSequenceArray", e);
        }

        return null;
    }

    /**
     * @param batchCSVRecord
     * @return
     */
    @Override
    public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) {
        try {
            SequenceBatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform")
                    .header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "true")
                    .body(batchCSVRecord)
                    .asObject(SequenceBatchCSVRecord.class).getBody();
            return batchCSVRecord1;
        } catch (UnirestException e) {
            log.error("Error in transformSequence", e);
        }

        return null;
    }

    /**
     * @param transform
     * @return
     */
    @Override
    public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) {
        try {
            SequenceBatchCSVRecord singleCsvRecord = Unirest.post(url + "/transformincremental")
                    .header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "true")
                    .body(transform).asObject(SequenceBatchCSVRecord.class).getBody();
            return singleCsvRecord;
        } catch (UnirestException e) {
            log.error("Error in transformSequenceIncremental", e);
        }
        return null;
    }
}

@@ -1,45 +0,0 @@
/* Apache License 2.0 header, identical to the one above (SPDX-License-Identifier: Apache-2.0) */

package org.datavec.transform.client;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;

import java.util.*;

@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {

    @Override
    protected Set<Class<?>> getExclusions() {
        //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
        return new HashSet<>();
    }

    @Override
    protected String getPackageName() {
        return "org.datavec.transform.client";
    }

    @Override
    protected Class<?> getBaseClass() {
        return BaseND4JTest.class;
    }
}

@@ -1,139 +0,0 @@
/* Apache License 2.0 header, identical to the one above (SPDX-License-Identifier: Apache-2.0) */

package org.datavec.transform.client;

import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.spark.inference.server.CSVSparkTransformServer;
import org.datavec.spark.inference.client.DataVecTransformClient;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;

import java.io.File;
import java.io.IOException;
import java.net.ServerSocket;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;

import static org.junit.Assert.assertEquals;
import static org.junit.Assume.assumeNotNull;

public class DataVecTransformClientTest {
    private static CSVSparkTransformServer server;
    private static int port = getAvailablePort();
    private static DataVecTransformClient client;
    private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
    private static TransformProcess transformProcess =
            new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build();
    private static File fileSave = new File(UUID.randomUUID().toString() + ".json");

    @BeforeClass
    public static void beforeClass() throws Exception {
        FileUtils.write(fileSave, transformProcess.toJson());
        fileSave.deleteOnExit();
        server = new CSVSparkTransformServer();
        server.runMain(new String[] {"-dp", String.valueOf(port)});

        client = new DataVecTransformClient("http://localhost:" + port);
        client.setCSVTransformProcess(transformProcess);
    }

    @AfterClass
    public static void afterClass() throws Exception {
        server.stop();
    }

    @Test
    public void testSequenceClient() {
        SequenceBatchCSVRecord sequenceBatchCSVRecord = new SequenceBatchCSVRecord();
        SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"});

        BatchCSVRecord batchCSVRecord = new BatchCSVRecord(Arrays.asList(singleCsvRecord, singleCsvRecord));
        List<BatchCSVRecord> batchCSVRecordList = new ArrayList<>();
        for (int i = 0; i < 5; i++) {
            batchCSVRecordList.add(batchCSVRecord);
        }

        sequenceBatchCSVRecord.add(batchCSVRecordList);

        SequenceBatchCSVRecord sequenceBatchCSVRecord1 = client.transformSequence(sequenceBatchCSVRecord);
        assumeNotNull(sequenceBatchCSVRecord1);

        Base64NDArrayBody array = client.transformSequenceArray(sequenceBatchCSVRecord);
        assumeNotNull(array);

        Base64NDArrayBody incrementalBody = client.transformSequenceArrayIncremental(batchCSVRecord);
        assumeNotNull(incrementalBody);

        Base64NDArrayBody incrementalSequenceBody = client.transformSequenceArrayIncremental(batchCSVRecord);
        assumeNotNull(incrementalSequenceBody);
    }

    @Test
    public void testRecord() throws Exception {
        SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"});
        SingleCSVRecord transformed = client.transformIncremental(singleCsvRecord);
        assertEquals(singleCsvRecord.getValues().size(), transformed.getValues().size());
        Base64NDArrayBody body = client.transformArrayIncremental(singleCsvRecord);
        INDArray arr = Nd4jBase64.fromBase64(body.getNdarray());
        assumeNotNull(arr);
    }

    @Test
    public void testBatchRecord() throws Exception {
        SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"});

        BatchCSVRecord batchCSVRecord = new BatchCSVRecord(Arrays.asList(singleCsvRecord, singleCsvRecord));
        BatchCSVRecord batchCSVRecord1 = client.transform(batchCSVRecord);
        assertEquals(batchCSVRecord.getRecords().size(), batchCSVRecord1.getRecords().size());

        Base64NDArrayBody body = client.transformArray(batchCSVRecord);
        INDArray arr = Nd4jBase64.fromBase64(body.getNdarray());
        assumeNotNull(arr);
    }

    public static int getAvailablePort() {
        try {
            ServerSocket socket = new ServerSocket(0);
            try {
                return socket.getLocalPort();
            } finally {
                socket.close();
            }
        } catch (IOException e) {
            throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e);
        }
    }
}

@@ -1,6 +0,0 @@
play.modules.enabled += com.lightbend.lagom.discovery.zookeeper.ZooKeeperServiceLocatorModule
play.modules.enabled += io.skymind.skil.service.PredictionModule
play.crypto.secret = as8dufasdfuasdfjkasdkfalksjfk
play.server.pidfile.path=/tmp/RUNNING_PID

play.server.http.port = 9600

@@ -1,63 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!-- Apache License 2.0 header, identical to the one above (SPDX-License-Identifier: Apache-2.0) -->

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-spark-inference-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>datavec-spark-inference-model</artifactId>

    <name>datavec-spark-inference-model</name>

    <dependencies>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-api</artifactId>
            <version>${datavec.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-data-image</artifactId>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-local</artifactId>
            <version>${project.version}</version>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
        </profile>
    </profiles>
</project>

@@ -1,286 +0,0 @@
/* Apache License 2.0 header, identical to the one above (SPDX-License-Identifier: Apache-2.0) */

package org.datavec.spark.inference.model;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.Writable;
import org.datavec.arrow.ArrowConverter;
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;

import static org.datavec.arrow.ArrowConverter.*;
import static org.datavec.local.transforms.LocalTransformExecutor.execute;
import static org.datavec.local.transforms.LocalTransformExecutor.executeToSequence;

@AllArgsConstructor
@Slf4j
public class CSVSparkTransform {
    @Getter
    private TransformProcess transformProcess;
    private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);

    /**
     * Convert a raw record via
     * the {@link TransformProcess}
     * to a base 64ed ndarray
     * @param batch the record to convert
     * @return the base 64ed ndarray
     * @throws IOException
     */
    public Base64NDArrayBody toArray(BatchCSVRecord batch) throws IOException {
        List<List<Writable>> converted = execute(toArrowWritables(toArrowColumnsString(
                bufferAllocator, transformProcess.getInitialSchema(),
                batch.getRecordsAsString()),
                transformProcess.getInitialSchema()), transformProcess);

        ArrowWritableRecordBatch arrowRecordBatch = (ArrowWritableRecordBatch) converted;
        INDArray convert = ArrowConverter.toArray(arrowRecordBatch);
        return new Base64NDArrayBody(Nd4jBase64.base64String(convert));
    }

    /**
     * Convert a raw record via
     * the {@link TransformProcess}
     * to a base 64ed ndarray
     * @param record the record to convert
     * @return the base 64ed ndarray
     * @throws IOException
     */
    public Base64NDArrayBody toArray(SingleCSVRecord record) throws IOException {
        List<Writable> record2 = toArrowWritablesSingle(
                toArrowColumnsStringSingle(bufferAllocator,
                        transformProcess.getInitialSchema(), record.getValues()),
                transformProcess.getInitialSchema());
        List<Writable> finalRecord = execute(Arrays.asList(record2), transformProcess).get(0);
        INDArray convert = RecordConverter.toArray(DataType.DOUBLE, finalRecord);
        return new Base64NDArrayBody(Nd4jBase64.base64String(convert));
    }

    /**
     * Runs the transform process
     * @param batch the record to transform
     * @return the transformed record
     */
    public BatchCSVRecord transform(BatchCSVRecord batch) {
        BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
        List<List<Writable>> converted = execute(toArrowWritables(toArrowColumnsString(
                bufferAllocator, transformProcess.getInitialSchema(),
                batch.getRecordsAsString()),
                transformProcess.getInitialSchema()), transformProcess);
        int numCols = converted.get(0).size();
        for (int row = 0; row < converted.size(); row++) {
            String[] values = new String[numCols];
            for (int i = 0; i < values.length; i++)
                values[i] = converted.get(row).get(i).toString();
            batchCSVRecord.add(new SingleCSVRecord(values));
        }

        return batchCSVRecord;
    }

    /**
     * Runs the transform process
     * @param record the record to transform
     * @return the transformed record
     */
    public SingleCSVRecord transform(SingleCSVRecord record) {
        List<Writable> record2 = toArrowWritablesSingle(
                toArrowColumnsStringSingle(bufferAllocator,
                        transformProcess.getInitialSchema(), record.getValues()),
                transformProcess.getInitialSchema());
        List<Writable> finalRecord = execute(Arrays.asList(record2), transformProcess).get(0);
        String[] values = new String[finalRecord.size()];
        for (int i = 0; i < values.length; i++)
            values[i] = finalRecord.get(i).toString();
        return new SingleCSVRecord(values);
    }

    /**
     *
     * @param transform
     * @return
     */
    public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) {
        // Sequence schema?
        List<List<List<Writable>>> converted = executeToSequence(
                toArrowWritables(toArrowColumnsStringTimeSeries(
                        bufferAllocator, transformProcess.getInitialSchema(),
                        Arrays.asList(transform.getRecordsAsString())),
                        transformProcess.getInitialSchema()), transformProcess);

        SequenceBatchCSVRecord batchCSVRecord = new SequenceBatchCSVRecord();
        for (int i = 0; i < converted.size(); i++) {
            BatchCSVRecord batchCSVRecord1 = BatchCSVRecord.fromWritables(converted.get(i));
            batchCSVRecord.add(Arrays.asList(batchCSVRecord1));
        }

        return batchCSVRecord;
    }

    /**
     *
     * @param batchCSVRecordSequence
     * @return
     */
    public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecordSequence) {
        List<List<List<String>>> recordsAsString = batchCSVRecordSequence.getRecordsAsString();
        boolean allSameLength = true;
        Integer length = null;
        for (List<List<String>> record : recordsAsString) {
            if (length == null) {
                length = record.size();
            } else if (record.size() != length) {
                allSameLength = false;
            }
        }

        if (allSameLength) {
            List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), recordsAsString);
            ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors,
                    transformProcess.getInitialSchema(),
                    recordsAsString.get(0).get(0).size());
            val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch, transformProcess);
            return SequenceBatchCSVRecord.fromWritables(transformed);
        } else {
            val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecordSequence.getRecordsAsString(), transformProcess.getInitialSchema()), transformProcess);
            return SequenceBatchCSVRecord.fromWritables(transformed);
        }
    }

    /**
     * TODO: optimize
     * @param batchCSVRecordSequence
     * @return
     */
    public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecordSequence) {
        List<List<List<String>>> strings = batchCSVRecordSequence.getRecordsAsString();
        boolean allSameLength = true;
        Integer length = null;
        for (List<List<String>> record : strings) {
            if (length == null) {
                length = record.size();
            } else if (record.size() != length) {
                allSameLength = false;
            }
        }

        if (allSameLength) {
            List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), strings);
            ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors, transformProcess.getInitialSchema(), strings.get(0).get(0).size());
            val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch, transformProcess);
            // reshape to [numExamples, numColumns, timeSteps]
            INDArray arr = RecordConverter.toTensor(transformed).reshape(strings.size(), strings.get(0).get(0).size(), strings.get(0).size());
            try {
                return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
            } catch (IOException e) {
                throw new IllegalStateException(e);
            }
        } else {
            val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecordSequence.getRecordsAsString(), transformProcess.getInitialSchema()), transformProcess);
            INDArray arr = RecordConverter.toTensor(transformed).reshape(strings.size(), strings.get(0).get(0).size(), strings.get(0).size());
            try {
                return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
            } catch (IOException e) {
                throw new IllegalStateException(e);
            }
        }
    }

    /**
     *
     * @param singleCsvRecord
     * @return
     */
    public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) {
        List<List<List<Writable>>> converted = executeToSequence(toArrowWritables(toArrowColumnsString(
                bufferAllocator, transformProcess.getInitialSchema(),
                singleCsvRecord.getRecordsAsString()),
                transformProcess.getInitialSchema()), transformProcess);
        ArrowWritableRecordTimeSeriesBatch arrowWritableRecordBatch = (ArrowWritableRecordTimeSeriesBatch) converted;
        INDArray arr = RecordConverter.toTensor(arrowWritableRecordBatch);
        try {
            return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
        } catch (IOException e) {
            log.error("Error in transformSequenceArrayIncremental", e);
        }

        return null;
    }

    public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
        List<List<List<String>>> strings = batchCSVRecord.getRecordsAsString();
        boolean allSameLength = true;
        Integer length = null;
        for (List<List<String>> record : strings) {
            if (length == null) {
                length = record.size();
            } else if (record.size() != length) {
                allSameLength = false;
            }
        }

        if (allSameLength) {
            List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), strings);
            ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors, transformProcess.getInitialSchema(), strings.get(0).get(0).size());
            val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch, transformProcess);
            return SequenceBatchCSVRecord.fromWritables(transformed);
        } else {
            val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecord.getRecordsAsString(), transformProcess.getInitialSchema()), transformProcess);
            return SequenceBatchCSVRecord.fromWritables(transformed);
        }
    }
}

@@ -1,64 +0,0 @@
/* Apache License 2.0 header, identical to the one above (SPDX-License-Identifier: Apache-2.0) */

package org.datavec.spark.inference.model;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.datavec.image.data.ImageWritable;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchImageRecord;
import org.datavec.spark.inference.model.model.SingleImageRecord;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.serde.base64.Nd4jBase64;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

@AllArgsConstructor
public class ImageSparkTransform {
    @Getter
    private ImageTransformProcess imageTransformProcess;

    public Base64NDArrayBody toArray(SingleImageRecord record) throws IOException {
        ImageWritable record2 = imageTransformProcess.transformFileUriToInput(record.getUri());
        INDArray finalRecord = imageTransformProcess.executeArray(record2);

        return new Base64NDArrayBody(Nd4jBase64.base64String(finalRecord));
    }

    public Base64NDArrayBody toArray(BatchImageRecord batch) throws IOException {
        List<INDArray> records = new ArrayList<>();

        for (SingleImageRecord imgRecord : batch.getRecords()) {
            ImageWritable record2 = imageTransformProcess.transformFileUriToInput(imgRecord.getUri());
            INDArray finalRecord = imageTransformProcess.executeArray(record2);
            records.add(finalRecord);
        }

        INDArray array = Nd4j.concat(0, records.toArray(new INDArray[records.size()]));

        return new Base64NDArrayBody(Nd4jBase64.base64String(array));
    }
}

@@ -1,32 +0,0 @@
/* Apache License 2.0 header, identical to the one above (SPDX-License-Identifier: Apache-2.0) */

package org.datavec.spark.inference.model.model;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class Base64NDArrayBody {
    private String ndarray;
}

@@ -1,104 +0,0 @@
/* Apache License 2.0 header, identical to the one above (SPDX-License-Identifier: Apache-2.0) */

package org.datavec.spark.inference.model.model;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.dataset.DataSet;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

@Data
@AllArgsConstructor
@Builder
@NoArgsConstructor
public class BatchCSVRecord implements Serializable {
    private List<SingleCSVRecord> records;

    /**
     * Get the records as a list of strings
     * (basically the underlying values for
     * {@link SingleCSVRecord})
     * @return
     */
    public List<List<String>> getRecordsAsString() {
        if (records == null)
            records = new ArrayList<>();
        List<List<String>> ret = new ArrayList<>();
        for (SingleCSVRecord csvRecord : records) {
            ret.add(csvRecord.getValues());
        }
        return ret;
    }

    /**
     * Create a batch csv record
     * from a list of writables.
     * @param batch
     * @return
     */
    public static BatchCSVRecord fromWritables(List<List<Writable>> batch) {
        List<SingleCSVRecord> records = new ArrayList<>(batch.size());
        for (List<Writable> list : batch) {
            List<String> add = new ArrayList<>(list.size());
            for (Writable writable : list) {
                add.add(writable.toString());
            }
            records.add(new SingleCSVRecord(add));
        }

        return BatchCSVRecord.builder().records(records).build();
    }

    /**
     * Add a record
     * @param record
     */
    public void add(SingleCSVRecord record) {
        if (records == null)
            records = new ArrayList<>();
        records.add(record);
    }

    /**
     * Return a batch record based on a dataset
     * @param dataSet the dataset to get the batch record for
     * @return the batch record
     */
    public static BatchCSVRecord fromDataSet(DataSet dataSet) {
        BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
        for (int i = 0; i < dataSet.numExamples(); i++) {
            batchCSVRecord.add(SingleCSVRecord.fromRow(dataSet.get(i)));
        }

        return batchCSVRecord;
    }
}

@@ -1,50 +0,0 @@
/* Apache License 2.0 header, identical to the one above (SPDX-License-Identifier: Apache-2.0) */

package org.datavec.spark.inference.model.model;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.net.URI;
import java.util.ArrayList;
import java.util.List;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class BatchImageRecord {
    private List<SingleImageRecord> records;

    /**
     * Add a record
     * @param record
     */
    public void add(SingleImageRecord record) {
        if (records == null)
            records = new ArrayList<>();
        records.add(record);
    }

    public void add(URI uri) {
        this.add(new SingleImageRecord(uri));
    }
}

@@ -1,106 +0,0 @@
/* Apache License 2.0 header, identical to the one above (SPDX-License-Identifier: Apache-2.0) */

package org.datavec.spark.inference.model.model;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

@Data
@AllArgsConstructor
@Builder
@NoArgsConstructor
public class SequenceBatchCSVRecord implements Serializable {
    private List<List<BatchCSVRecord>> records;

    /**
     * Add a record
     * @param record
     */
    public void add(List<BatchCSVRecord> record) {
        if (records == null)
            records = new ArrayList<>();
        records.add(record);
    }

    /**
     * Get the records as a list of strings directly
     * (this basically "unpacks" the objects)
     * @return
     */
    public List<List<List<String>>> getRecordsAsString() {
        if (records == null)
            return Collections.emptyList();
        List<List<List<String>>> ret = new ArrayList<>(records.size());
        for (List<BatchCSVRecord> record : records) {
            List<List<String>> add = new ArrayList<>();
            for (BatchCSVRecord batchCSVRecord : record) {
                for (SingleCSVRecord singleCSVRecord : batchCSVRecord.getRecords()) {
                    add.add(singleCSVRecord.getValues());
                }
            }

            ret.add(add);
        }

        return ret;
    }

    /**
     * Convert a writables time series to a sequence batch
     * @param input
     * @return
     */
    public static SequenceBatchCSVRecord fromWritables(List<List<List<Writable>>> input) {
        SequenceBatchCSVRecord ret = new SequenceBatchCSVRecord();
        for (int i = 0; i < input.size(); i++) {
            ret.add(Arrays.asList(BatchCSVRecord.fromWritables(input.get(i))));
        }

        return ret;
    }

    /**
     * Return a batch record based on a dataset
     * @param dataSet the dataset to get the batch record for
     * @return the batch record
     */
    public static SequenceBatchCSVRecord fromDataSet(MultiDataSet dataSet) {
        SequenceBatchCSVRecord batchCSVRecord = new SequenceBatchCSVRecord();
        for (int i = 0; i < dataSet.numFeatureArrays(); i++) {
            batchCSVRecord.add(Arrays.asList(BatchCSVRecord.fromDataSet(new DataSet(dataSet.getFeatures(i), dataSet.getLabels(i)))));
        }

        return batchCSVRecord;
    }
}

@ -1,95 +0,0 @@
|
||||||
/*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.datavec.spark.inference.model.model;
|
|
||||||
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import org.nd4j.linalg.dataset.DataSet;
|
|
||||||
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
@Data
|
|
||||||
@AllArgsConstructor
|
|
||||||
@NoArgsConstructor
|
|
||||||
public class SingleCSVRecord implements Serializable {
|
|
||||||
private List<String> values;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create from an array of values uses list internally)
|
|
||||||
* @param values
|
|
||||||
*/
|
|
||||||
public SingleCSVRecord(String...values) {
|
|
||||||
this.values = Arrays.asList(values);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Instantiate a csv record from a vector
|
|
||||||
* given either an input dataset and a
|
|
||||||
* one hot matrix, the index will be appended to
|
|
||||||
* the end of the record, or for regression
|
|
||||||
* it will append all values in the labels
|
|
||||||
* @param row the input vectors
|
|
||||||
* @return the record from this {@link DataSet}
|
|
||||||
*/
|
|
||||||
public static SingleCSVRecord fromRow(DataSet row) {
|
|
||||||
if (!row.getFeatures().isVector() && !row.getFeatures().isScalar())
|
|
||||||
throw new IllegalArgumentException("Passed in dataset must represent a scalar or vector");
|
|
||||||
if (!row.getLabels().isVector() && !row.getLabels().isScalar())
|
|
||||||
throw new IllegalArgumentException("Passed in dataset labels must be a scalar or vector");
|
|
||||||
//classification
|
|
||||||
SingleCSVRecord record;
|
|
||||||
int idx = 0;
|
|
||||||
if (row.getLabels().sumNumber().doubleValue() == 1.0) {
|
|
||||||
String[] values = new String[row.getFeatures().columns() + 1];
|
|
||||||
for (int i = 0; i < row.getFeatures().length(); i++) {
|
|
||||||
values[idx++] = String.valueOf(row.getFeatures().getDouble(i));
|
|
||||||
}
|
|
||||||
int maxIdx = 0;
|
|
||||||
for (int i = 0; i < row.getLabels().length(); i++) {
|
|
||||||
if (row.getLabels().getDouble(maxIdx) < row.getLabels().getDouble(i)) {
|
|
||||||
maxIdx = i;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
values[idx++] = String.valueOf(maxIdx);
|
|
||||||
record = new SingleCSVRecord(values);
|
|
||||||
}
|
|
||||||
//regression (any number of values)
|
|
||||||
else {
|
|
||||||
String[] values = new String[row.getFeatures().columns() + row.getLabels().columns()];
|
|
||||||
for (int i = 0; i < row.getFeatures().length(); i++) {
|
|
||||||
values[idx++] = String.valueOf(row.getFeatures().getDouble(i));
|
|
||||||
}
|
|
||||||
for (int i = 0; i < row.getLabels().length(); i++) {
|
|
||||||
values[idx++] = String.valueOf(row.getLabels().getDouble(i));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
record = new SingleCSVRecord(values);
|
|
||||||
|
|
||||||
}
|
|
||||||
return record;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
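
A minimal sketch of fromRow's two branches, reusing the Nd4j factory calls that appear in the tests removed later in this commit; shapes, values, and the class name are illustrative:

    import org.datavec.spark.inference.model.model.SingleCSVRecord;
    import org.nd4j.linalg.dataset.DataSet;
    import org.nd4j.linalg.factory.Nd4j;

    public class SingleCSVRecordSketch {
        public static void main(String[] args) {
            // One-hot labels (sum == 1.0) take the classification branch:
            // the features plus the argmax index -> "0.0", "0.0", "1"
            DataSet classRow = new DataSet(Nd4j.create(1, 2), Nd4j.create(new double[][] {{0, 1}}));
            System.out.println(SingleCSVRecord.fromRow(classRow).getValues()); // 3 values

            // Any other labels take the regression branch: the features plus every label value
            DataSet regRow = new DataSet(Nd4j.create(1, 2), Nd4j.create(new double[][] {{1, 1}}));
            System.out.println(SingleCSVRecord.fromRow(regRow).getValues()); // 4 values
        }
    }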
@@ -1,34 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.model.model;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.net.URI;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class SingleImageRecord {
    private URI uri;
}
@@ -1,131 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.model.service;

import org.datavec.api.transform.TransformProcess;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.model.*;

import java.io.IOException;

public interface DataVecTransformService {

    String SEQUENCE_OR_NOT_HEADER = "Sequence";


    /**
     * Set the transform process to use for CSV records
     * @param transformProcess the transform process
     */
    void setCSVTransformProcess(TransformProcess transformProcess);

    /**
     * Set the transform process to use for images
     * @param imageTransformProcess the image transform process
     */
    void setImageTransformProcess(ImageTransformProcess imageTransformProcess);

    /**
     * @return the current CSV transform process
     */
    TransformProcess getCSVTransformProcess();

    /**
     * @return the current image transform process
     */
    ImageTransformProcess getImageTransformProcess();

    /**
     * Transform a single CSV record
     * @param singleCsvRecord the record to transform
     * @return the transformed record
     */
    SingleCSVRecord transformIncremental(SingleCSVRecord singleCsvRecord);

    SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord);

    /**
     * Transform a batch of CSV records
     * @param batchCSVRecord the batch to transform
     * @return the transformed batch
     */
    BatchCSVRecord transform(BatchCSVRecord batchCSVRecord);

    /**
     * Transform a batch of CSV records into a base64-encoded NDArray
     * @param batchCSVRecord the batch to transform
     * @return the array for the batch
     */
    Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord);

    /**
     * Transform a single CSV record into a base64-encoded NDArray
     * @param singleCsvRecord the record to transform
     * @return the array for the record
     */
    Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord);

    /**
     * Transform a single image into a base64-encoded NDArray
     * @param singleImageRecord the image record to transform
     * @return the array for the image
     * @throws IOException if the image could not be read
     */
    Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException;

    /**
     * Transform a batch of images into a base64-encoded NDArray
     * @param batchImageRecord the image batch to transform
     * @return the array for the batch
     * @throws IOException if an image could not be read
     */
    Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException;

    /**
     * Transform a batch of CSV records into a base64-encoded sequence NDArray
     * @param singleCsvRecord the batch to transform
     * @return the sequence array
     */
    Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord);

    /**
     * Transform a sequence batch into a base64-encoded NDArray
     * @param batchCSVRecord the sequence batch to transform
     * @return the sequence array
     */
    Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord);

    /**
     * Transform a sequence batch of CSV records
     * @param batchCSVRecord the sequence batch to transform
     * @return the transformed sequence batch
     */
    SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord);

    /**
     * Transform a batch of CSV records into a sequence batch
     * @param transform the batch to transform
     * @return the transformed sequence batch
     */
    SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform);
}
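
For orientation, a caller-side sketch of the CSV half of this contract, wired to the CSVSparkTransformServer implementation removed later in this commit; the schema, column names, and class name are illustrative:

    import org.datavec.api.transform.TransformProcess;
    import org.datavec.api.transform.schema.Schema;
    import org.datavec.spark.inference.model.model.Base64NDArrayBody;
    import org.datavec.spark.inference.model.model.SingleCSVRecord;
    import org.datavec.spark.inference.model.service.DataVecTransformService;
    import org.datavec.spark.inference.server.CSVSparkTransformServer;

    public class DataVecTransformServiceSketch {
        public static void main(String[] args) {
            // Build a trivial TransformProcess, mirroring the tests removed below
            Schema schema = new Schema.Builder().addColumnDouble("c1").addColumnDouble("c2").build();
            TransformProcess tp = new TransformProcess.Builder(schema)
                    .convertToString("c1").convertToString("c2").build();

            DataVecTransformService service = new CSVSparkTransformServer();
            service.setCSVTransformProcess(tp);

            // Transform a single record, or fetch it as a base64-encoded NDArray
            SingleCSVRecord out = service.transformIncremental(new SingleCSVRecord("1.0", "2.0"));
            Base64NDArrayBody array = service.transformArrayIncremental(new SingleCSVRecord("1.0", "2.0"));
        }
    }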
@@ -1,46 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.datavec.spark.transform;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;

import java.util.*;

@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {

    @Override
    protected Set<Class<?>> getExclusions() {
        //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
        return new HashSet<>();
    }

    @Override
    protected String getPackageName() {
        return "org.datavec.spark.transform";
    }

    @Override
    protected Class<?> getBaseClass() {
        return BaseND4JTest.class;
    }
}
@@ -1,40 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.transform;

import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.junit.Test;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

import static org.junit.Assert.assertEquals;

public class BatchCSVRecordTest {

    @Test
    public void testBatchRecordCreationFromDataSet() {
        DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{1, 1}, {1, 1}}));

        BatchCSVRecord batchCSVRecord = BatchCSVRecord.fromDataSet(dataSet);
        assertEquals(2, batchCSVRecord.getRecords().size());
    }

}
@@ -1,212 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.transform;

import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.integer.BaseIntegerTransform;
import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.spark.inference.model.CSVSparkTransform;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;

import java.util.*;

import static org.junit.Assert.*;

public class CSVSparkTransformTest {
    @Test
    public void testTransformer() throws Exception {
        List<Writable> input = new ArrayList<>();
        input.add(new DoubleWritable(1.0));
        input.add(new DoubleWritable(2.0));

        Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
        List<Writable> output = new ArrayList<>();
        output.add(new Text("1.0"));
        output.add(new Text("2.0"));

        TransformProcess transformProcess =
                        new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build();
        CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess);
        String[] values = new String[] {"1.0", "2.0"};
        SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values));
        Base64NDArrayBody body = csvSparkTransform.toArray(new SingleCSVRecord(values));
        INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
        assertTrue(fromBase64.isVector());
        //        System.out.println("Base 64ed array " + fromBase64);
    }

    @Test
    public void testTransformerBatch() throws Exception {
        List<Writable> input = new ArrayList<>();
        input.add(new DoubleWritable(1.0));
        input.add(new DoubleWritable(2.0));

        Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
        List<Writable> output = new ArrayList<>();
        output.add(new Text("1.0"));
        output.add(new Text("2.0"));

        TransformProcess transformProcess =
                        new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build();
        CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess);
        String[] values = new String[] {"1.0", "2.0"};
        SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values));
        BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
        for (int i = 0; i < 3; i++)
            batchCSVRecord.add(record);
        //data type is string, unable to convert
        BatchCSVRecord batchCSVRecord1 = csvSparkTransform.transform(batchCSVRecord);
        /*  Base64NDArrayBody body = csvSparkTransform.toArray(batchCSVRecord1);
        INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
        assertTrue(fromBase64.isMatrix());
        System.out.println("Base 64ed array " + fromBase64); */
    }



    @Test
    public void testSingleBatchSequence() throws Exception {
        List<Writable> input = new ArrayList<>();
        input.add(new DoubleWritable(1.0));
        input.add(new DoubleWritable(2.0));

        Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
        List<Writable> output = new ArrayList<>();
        output.add(new Text("1.0"));
        output.add(new Text("2.0"));

        TransformProcess transformProcess =
                        new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build();
        CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess);
        String[] values = new String[] {"1.0", "2.0"};
        SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values));
        BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
        for (int i = 0; i < 3; i++)
            batchCSVRecord.add(record);
        BatchCSVRecord batchCSVRecord1 = csvSparkTransform.transform(batchCSVRecord);
        SequenceBatchCSVRecord sequenceBatchCSVRecord = new SequenceBatchCSVRecord();
        sequenceBatchCSVRecord.add(Arrays.asList(batchCSVRecord));
        Base64NDArrayBody sequenceArray = csvSparkTransform.transformSequenceArray(sequenceBatchCSVRecord);
        INDArray outputBody = Nd4jBase64.fromBase64(sequenceArray.getNdarray());


        //ensure accumulation
        sequenceBatchCSVRecord.add(Arrays.asList(batchCSVRecord));
        sequenceArray = csvSparkTransform.transformSequenceArray(sequenceBatchCSVRecord);
        assertArrayEquals(new long[] {2, 2, 3}, Nd4jBase64.fromBase64(sequenceArray.getNdarray()).shape());

        SequenceBatchCSVRecord transformed = csvSparkTransform.transformSequence(sequenceBatchCSVRecord);
        assertNotNull(transformed.getRecords());
        //        System.out.println(transformed);
    }

    @Test
    public void testSpecificSequence() throws Exception {
        final Schema schema = new Schema.Builder()
                        .addColumnsString("action")
                        .build();

        final TransformProcess transformProcess = new TransformProcess.Builder(schema)
                        .removeAllColumnsExceptFor("action")
                        .transform(new ConverToLowercase("action"))
                        .convertToSequence()
                        .transform(new TextToCharacterIndexTransform("action", "action_sequence",
                                        defaultCharIndex(), false))
                        .integerToOneHot("action_sequence", 0, 29)
                        .build();

        final String[] data1 = new String[] {"test1"};
        final String[] data2 = new String[] {"test2"};
        final BatchCSVRecord batchCsvRecord = new BatchCSVRecord(
                        Arrays.asList(
                                        new SingleCSVRecord(data1),
                                        new SingleCSVRecord(data2)));

        final CSVSparkTransform transform = new CSVSparkTransform(transformProcess);
        //        System.out.println(transform.transformSequenceIncremental(batchCsvRecord));
        transform.transformSequenceIncremental(batchCsvRecord);
        assertEquals(3, Nd4jBase64.fromBase64(transform.transformSequenceArrayIncremental(batchCsvRecord).getNdarray()).rank());
    }

    private static Map<Character, Integer> defaultCharIndex() {
        Map<Character, Integer> ret = new TreeMap<>();

        ret.put('a', 0);
        ret.put('b', 1);
        ret.put('c', 2);
        ret.put('d', 3);
        ret.put('e', 4);
        ret.put('f', 5);
        ret.put('g', 6);
        ret.put('h', 7);
        ret.put('i', 8);
        ret.put('j', 9);
        ret.put('k', 10);
        ret.put('l', 11);
        ret.put('m', 12);
        ret.put('n', 13);
        ret.put('o', 14);
        ret.put('p', 15);
        ret.put('q', 16);
        ret.put('r', 17);
        ret.put('s', 18);
        ret.put('t', 19);
        ret.put('u', 20);
        ret.put('v', 21);
        ret.put('w', 22);
        ret.put('x', 23);
        ret.put('y', 24);
        ret.put('z', 25);
        ret.put('/', 26);
        ret.put(' ', 27);
        ret.put('(', 28);
        ret.put(')', 29);

        return ret;
    }

    public static class ConverToLowercase extends BaseIntegerTransform {
        public ConverToLowercase(String column) {
            super(column);
        }

        public Text map(Writable writable) {
            return new Text(writable.toString().toLowerCase());
        }

        public Object map(Object input) {
            return new Text(input.toString().toLowerCase());
        }
    }
}
@@ -1,86 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.transform;

import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.ImageSparkTransform;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchImageRecord;
import org.datavec.spark.inference.model.model.SingleImageRecord;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.serde.base64.Nd4jBase64;

import java.io.File;

import static org.junit.Assert.assertEquals;

public class ImageSparkTransformTest {

    @Rule
    public TemporaryFolder testDir = new TemporaryFolder();

    @Test
    public void testSingleImageSparkTransform() throws Exception {
        int seed = 12345;

        File f1 = new ClassPathResource("datavec-spark-inference/testimages/class1/A.jpg").getFile();

        SingleImageRecord imgRecord = new SingleImageRecord(f1.toURI());

        ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(seed)
                        .scaleImageTransform(10).cropImageTransform(5).build();

        ImageSparkTransform imgSparkTransform = new ImageSparkTransform(imgTransformProcess);
        Base64NDArrayBody body = imgSparkTransform.toArray(imgRecord);

        INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
        //        System.out.println("Base 64ed array " + fromBase64);
        assertEquals(1, fromBase64.size(0));
    }

    @Test
    public void testBatchImageSparkTransform() throws Exception {
        int seed = 12345;

        File f0 = new ClassPathResource("datavec-spark-inference/testimages/class1/A.jpg").getFile();
        File f1 = new ClassPathResource("datavec-spark-inference/testimages/class1/B.png").getFile();
        File f2 = new ClassPathResource("datavec-spark-inference/testimages/class1/C.jpg").getFile();

        BatchImageRecord batch = new BatchImageRecord();
        batch.add(f0.toURI());
        batch.add(f1.toURI());
        batch.add(f2.toURI());

        ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(seed)
                        .scaleImageTransform(10).cropImageTransform(5).build();

        ImageSparkTransform imgSparkTransform = new ImageSparkTransform(imgTransformProcess);
        Base64NDArrayBody body = imgSparkTransform.toArray(batch);

        INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
        //        System.out.println("Base 64ed array " + fromBase64);
        assertEquals(3, fromBase64.size(0));
    }
}
@@ -1,60 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.transform;

import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.junit.Test;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;

public class SingleCSVRecordTest {

    @Test(expected = IllegalArgumentException.class)
    public void testVectorAssertion() {
        DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(1, 1));
        SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet);
        fail(singleCsvRecord.toString() + " should have thrown an exception");
    }

    @Test
    public void testVectorOneHotLabel() {
        DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{0, 1}, {1, 0}}));

        //assert
        SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet.get(0));
        assertEquals(3, singleCsvRecord.getValues().size());
    }

    @Test
    public void testVectorRegression() {
        DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{1, 1}, {1, 1}}));

        //assert
        SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet.get(0));
        assertEquals(4, singleCsvRecord.getValues().size());
    }

}
@@ -1,47 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.transform;

import org.datavec.spark.inference.model.model.SingleImageRecord;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.io.ClassPathResource;

import java.io.File;

public class SingleImageRecordTest {

    @Rule
    public TemporaryFolder testDir = new TemporaryFolder();

    @Test
    public void testImageRecord() throws Exception {
        File f = testDir.newFolder();
        new ClassPathResource("datavec-spark-inference/testimages/").copyDirectory(f);
        File f0 = new File(f, "class0/0.jpg");
        File f1 = new File(f, "/class1/A.jpg");

        SingleImageRecord imgRecord = new SingleImageRecord(f0.toURI());

        // need jackson test?
    }
}
@@ -1,154 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  ~ /* ******************************************************************************
  ~ *
  ~ *
  ~ * This program and the accompanying materials are made available under the
  ~ * terms of the Apache License, Version 2.0 which is available at
  ~ * https://www.apache.org/licenses/LICENSE-2.0.
  ~ *
  ~ * See the NOTICE file distributed with this work for additional
  ~ * information regarding copyright ownership.
  ~ * Unless required by applicable law or agreed to in writing, software
  ~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~ * License for the specific language governing permissions and limitations
  ~ * under the License.
  ~ *
  ~ * SPDX-License-Identifier: Apache-2.0
  ~ ******************************************************************************/
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-spark-inference-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>datavec-spark-inference-server_2.11</artifactId>

    <name>datavec-spark-inference-server</name>

    <properties>
        <!-- Default scala versions, may be overwritten by build profiles -->
        <scala.version>2.11.12</scala.version>
        <scala.binary.version>2.11</scala.binary.version>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-spark-inference-model</artifactId>
            <version>${datavec.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-spark_2.11</artifactId>
            <version>${project.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-data-image</artifactId>
        </dependency>
        <dependency>
            <groupId>joda-time</groupId>
            <artifactId>joda-time</artifactId>
        </dependency>
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-lang3</artifactId>
        </dependency>
        <dependency>
            <groupId>org.hibernate</groupId>
            <artifactId>hibernate-validator</artifactId>
            <version>${hibernate.version}</version>
        </dependency>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>${scala.version}</version>
        </dependency>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-reflect</artifactId>
            <version>${scala.version}</version>
        </dependency>
        <dependency>
            <groupId>com.typesafe.play</groupId>
            <artifactId>play-java_2.11</artifactId>
            <version>${playframework.version}</version>
            <exclusions>
                <exclusion>
                    <groupId>com.google.code.findbugs</groupId>
                    <artifactId>jsr305</artifactId>
                </exclusion>
                <exclusion>
                    <groupId>net.jodah</groupId>
                    <artifactId>typetools</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <dependency>
            <groupId>net.jodah</groupId>
            <artifactId>typetools</artifactId>
            <version>${jodah.typetools.version}</version>
        </dependency>
        <dependency>
            <groupId>com.typesafe.play</groupId>
            <artifactId>play-json_2.11</artifactId>
            <version>${playframework.version}</version>
        </dependency>
        <dependency>
            <groupId>com.typesafe.play</groupId>
            <artifactId>play-server_2.11</artifactId>
            <version>${playframework.version}</version>
        </dependency>
        <dependency>
            <groupId>com.typesafe.play</groupId>
            <artifactId>play_2.11</artifactId>
            <version>${playframework.version}</version>
        </dependency>
        <dependency>
            <groupId>com.typesafe.play</groupId>
            <artifactId>play-netty-server_2.11</artifactId>
            <version>${playframework.version}</version>
        </dependency>
        <dependency>
            <groupId>com.typesafe.akka</groupId>
            <artifactId>akka-cluster_2.11</artifactId>
            <version>2.5.23</version>
        </dependency>
        <dependency>
            <groupId>com.mashape.unirest</groupId>
            <artifactId>unirest-java</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>com.beust</groupId>
            <artifactId>jcommander</artifactId>
            <version>${jcommander.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.11</artifactId>
            <version>${spark.version}</version>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
        </profile>
    </profiles>
</project>
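
The two empty profiles above presumably let the parent build select an ND4J backend for tests; they can be activated with Maven's standard -P flag, e.g.:

    mvn -P test-nd4j-native test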
@@ -1,352 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.server;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.ParameterException;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.CSVSparkTransform;
import org.datavec.spark.inference.model.model.*;
import play.BuiltInComponents;
import play.Mode;
import play.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;

import java.io.File;
import java.io.IOException;
import java.util.Base64;
import java.util.Random;

import static play.mvc.Results.*;

@Slf4j
@Data
public class CSVSparkTransformServer extends SparkTransformServer {
    private CSVSparkTransform transform;

    public void runMain(String[] args) throws Exception {
        JCommander jcmdr = new JCommander(this);

        try {
            jcmdr.parse(args);
        } catch (ParameterException e) {
            //User provided invalid input -> print the usage info
            jcmdr.usage();
            if (jsonPath == null)
                System.err.println("Json path parameter is missing.");
            try {
                Thread.sleep(500);
            } catch (Exception e2) {
            }
            System.exit(1);
        }

        if (jsonPath != null) {
            String json = FileUtils.readFileToString(new File(jsonPath));
            TransformProcess transformProcess = TransformProcess.fromJson(json);
            transform = new CSVSparkTransform(transformProcess);
        } else {
            log.warn("Server started with no json for transform process. Please ensure you specify a transform process by sending a POST request with raw json"
                            + " to /transformprocess");
        }

        //Set play secret key, if required
        //http://www.playframework.com/documentation/latest/ApplicationSecret
        String crypto = System.getProperty("play.crypto.secret");
        if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) {
            byte[] newCrypto = new byte[1024];

            new Random().nextBytes(newCrypto);

            String base64 = Base64.getEncoder().encodeToString(newCrypto);
            System.setProperty("play.crypto.secret", base64);
        }

        server = Server.forRouter(Mode.PROD, port, this::createRouter);
    }

    protected Router createRouter(BuiltInComponents b) {
        RoutingDsl routingDsl = RoutingDsl.fromComponents(b);

        routingDsl.GET("/transformprocess").routingTo(req -> {
            try {
                if (transform == null)
                    return badRequest();
                return ok(transform.getTransformProcess().toJson()).as(contentType);
            } catch (Exception e) {
                log.error("Error in GET /transformprocess", e);
                return internalServerError(e.getMessage());
            }
        });

        routingDsl.POST("/transformprocess").routingTo(req -> {
            try {
                TransformProcess transformProcess = TransformProcess.fromJson(getJsonText(req));
                setCSVTransformProcess(transformProcess);
                log.info("Transform process initialized");
                return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
            } catch (Exception e) {
                log.error("Error in POST /transformprocess", e);
                return internalServerError(e.getMessage());
            }
        });

        routingDsl.POST("/transformincremental").routingTo(req -> {
            if (isSequence(req)) {
                try {
                    BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
                    if (record == null)
                        return badRequest();
                    return ok(objectMapper.writeValueAsString(transformSequenceIncremental(record))).as(contentType);
                } catch (Exception e) {
                    log.error("Error in /transformincremental", e);
                    return internalServerError(e.getMessage());
                }
            } else {
                try {
                    SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class);
                    if (record == null)
                        return badRequest();
                    return ok(objectMapper.writeValueAsString(transformIncremental(record))).as(contentType);
                } catch (Exception e) {
                    log.error("Error in /transformincremental", e);
                    return internalServerError(e.getMessage());
                }
            }
        });

        routingDsl.POST("/transform").routingTo(req -> {
            if (isSequence(req)) {
                try {
                    SequenceBatchCSVRecord batch = transformSequence(objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class));
                    if (batch == null)
                        return badRequest();
                    return ok(objectMapper.writeValueAsString(batch)).as(contentType);
                } catch (Exception e) {
                    log.error("Error in /transform", e);
                    return internalServerError(e.getMessage());
                }
            } else {
                try {
                    BatchCSVRecord input = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
                    BatchCSVRecord batch = transform(input);
                    if (batch == null)
                        return badRequest();
                    return ok(objectMapper.writeValueAsString(batch)).as(contentType);
                } catch (Exception e) {
                    log.error("Error in /transform", e);
                    return internalServerError(e.getMessage());
                }
            }
        });

        routingDsl.POST("/transformincrementalarray").routingTo(req -> {
            if (isSequence(req)) {
                try {
                    BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
                    if (record == null)
                        return badRequest();
                    return ok(objectMapper.writeValueAsString(transformSequenceArrayIncremental(record))).as(contentType);
                } catch (Exception e) {
                    log.error("Error in /transformincrementalarray", e);
                    return internalServerError(e.getMessage());
                }
            } else {
                try {
                    SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class);
                    if (record == null)
                        return badRequest();
                    return ok(objectMapper.writeValueAsString(transformArrayIncremental(record))).as(contentType);
                } catch (Exception e) {
                    log.error("Error in /transformincrementalarray", e);
                    return internalServerError(e.getMessage());
                }
            }
        });

        routingDsl.POST("/transformarray").routingTo(req -> {
            if (isSequence(req)) {
                try {
                    SequenceBatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class);
                    if (batchCSVRecord == null)
                        return badRequest();
                    return ok(objectMapper.writeValueAsString(transformSequenceArray(batchCSVRecord))).as(contentType);
                } catch (Exception e) {
                    log.error("Error in /transformarray", e);
                    return internalServerError(e.getMessage());
                }
            } else {
                try {
                    BatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
                    if (batchCSVRecord == null)
                        return badRequest();
                    return ok(objectMapper.writeValueAsString(transformArray(batchCSVRecord))).as(contentType);
                } catch (Exception e) {
                    log.error("Error in /transformarray", e);
                    return internalServerError(e.getMessage());
                }
            }
        });

        return routingDsl.build();
    }

    public static void main(String[] args) throws Exception {
        new CSVSparkTransformServer().runMain(args);
    }

    /**
     * Set (and replace) the CSV transform process used by this server
     * @param transformProcess the transform process to use
     */
    @Override
    public void setCSVTransformProcess(TransformProcess transformProcess) {
        this.transform = new CSVSparkTransform(transformProcess);
    }

    @Override
    public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) {
        log.error("Unsupported operation: setImageTransformProcess not supported for class {}", getClass());
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    /**
     * @return the current CSV transform process
     */
    @Override
    public TransformProcess getCSVTransformProcess() {
        return transform.getTransformProcess();
    }

    @Override
    public ImageTransformProcess getImageTransformProcess() {
        log.error("Unsupported operation: getImageTransformProcess not supported for class {}", getClass());
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    /**
     * @param transform the batch to transform into a sequence batch
     * @return the transformed sequence batch
     */
    @Override
    public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) {
        return this.transform.transformSequenceIncremental(transform);
    }

    /**
     * @param batchCSVRecord the sequence batch to transform
     * @return the transformed sequence batch
     */
    @Override
    public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) {
        return transform.transformSequence(batchCSVRecord);
    }

    /**
     * @param batchCSVRecord the sequence batch to transform
     * @return the sequence array
     */
    @Override
    public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) {
        return this.transform.transformSequenceArray(batchCSVRecord);
    }

    /**
     * @param singleCsvRecord the batch to transform
     * @return the sequence array
     */
    @Override
    public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) {
        return this.transform.transformSequenceArrayIncremental(singleCsvRecord);
    }

    /**
     * @param transform the record to transform
     * @return the transformed record
     */
    @Override
    public SingleCSVRecord transformIncremental(SingleCSVRecord transform) {
        return this.transform.transform(transform);
    }

    @Override
    public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
        return this.transform.transform(batchCSVRecord);
    }

    /**
     * @param batchCSVRecord the batch to transform
     * @return the transformed batch
     */
    @Override
    public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) {
        return transform.transform(batchCSVRecord);
    }

    /**
     * @param batchCSVRecord the batch to transform
     * @return the array for the batch
     */
    @Override
    public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) {
        try {
            return this.transform.toArray(batchCSVRecord);
        } catch (IOException e) {
            log.error("Error in transformArray", e);
            throw new IllegalStateException("Transform array shouldn't throw exception");
        }
    }

    /**
     * @param singleCsvRecord the record to transform
     * @return the array for the record
     */
    @Override
    public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) {
        try {
            return this.transform.toArray(singleCsvRecord);
        } catch (IOException e) {
            log.error("Error in transformArrayIncremental", e);
            throw new IllegalStateException("Transform array shouldn't throw exception");
        }
    }

    @Override
    public Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException {
        log.error("Unsupported operation: transformIncrementalArray(SingleImageRecord) not supported for class {}", getClass());
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    @Override
    public Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException {
        log.error("Unsupported operation: transformArray(BatchImageRecord) not supported for class {}", getClass());
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }
}
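
A client-side sketch against the routes above, using the unirest-java test dependency from the server pom; the port (Play's default 9000), the placeholder JSON, and the class name are assumptions, not taken from the removed code:

    import com.mashape.unirest.http.HttpResponse;
    import com.mashape.unirest.http.Unirest;
    import com.mashape.unirest.http.exceptions.UnirestException;

    public class CSVSparkTransformClientSketch {
        public static void main(String[] args) throws UnirestException {
            String base = "http://localhost:9000"; // assumed port

            // Initialize the server-side TransformProcess from its JSON form
            String transformProcessJson = "..."; // TransformProcess.toJson() output (placeholder)
            Unirest.post(base + "/transformprocess")
                    .header("Content-Type", "application/json")
                    .body(transformProcessJson)
                    .asString();

            // Transform one record; the body mirrors SingleCSVRecord's "values" field
            HttpResponse<String> transformed = Unirest.post(base + "/transformincremental")
                    .header("Content-Type", "application/json")
                    .body("{\"values\":[\"1.0\",\"2.0\"]}")
                    .asString();
            System.out.println(transformed.getBody());
        }
    }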
@@ -1,261 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.server;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.ParameterException;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.ImageSparkTransform;
import org.datavec.spark.inference.model.model.*;
import play.BuiltInComponents;
import play.Mode;
import play.libs.Files;
import play.mvc.Http;
import play.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static play.mvc.Results.*;

@Slf4j
@Data
public class ImageSparkTransformServer extends SparkTransformServer {
    private ImageSparkTransform transform;

    public void runMain(String[] args) throws Exception {
        JCommander jcmdr = new JCommander(this);

        try {
            jcmdr.parse(args);
        } catch (ParameterException e) {
            //User provided invalid input -> print the usage info
            jcmdr.usage();
            if (jsonPath == null)
                System.err.println("Json path parameter is missing.");
            try {
                Thread.sleep(500);
            } catch (Exception e2) {
            }
            System.exit(1);
        }

        if (jsonPath != null) {
            String json = FileUtils.readFileToString(new File(jsonPath));
            ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(json);
            transform = new ImageSparkTransform(transformProcess);
        } else {
            log.warn("Server started with no json for transform process. Please ensure you specify a transform process by sending a POST request with raw json"
                            + " to /transformprocess");
        }

        server = Server.forRouter(Mode.PROD, port, this::createRouter);
    }

    protected Router createRouter(BuiltInComponents builtInComponents) {
        RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents);

        routingDsl.GET("/transformprocess").routingTo(req -> {
            try {
                if (transform == null)
                    return badRequest();
                log.info("Transform process initialized");
                return ok(objectMapper.writeValueAsString(transform.getImageTransformProcess())).as(contentType);
            } catch (Exception e) {
                log.error("Error in GET /transformprocess", e);
                return internalServerError();
            }
        });

        routingDsl.POST("/transformprocess").routingTo(req -> {
            try {
                ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(getJsonText(req));
                setImageTransformProcess(transformProcess);
                log.info("Transform process initialized");
                return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
            } catch (Exception e) {
                log.error("Error in POST /transformprocess", e);
                return internalServerError();
            }
        });

        routingDsl.POST("/transformincrementalarray").routingTo(req -> {
            try {
                SingleImageRecord record = objectMapper.readValue(getJsonText(req), SingleImageRecord.class);
                if (record == null)
                    return badRequest();
                return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
            } catch (Exception e) {
                log.error("Error in /transformincrementalarray", e);
                return internalServerError();
            }
        });

        routingDsl.POST("/transformincrementalimage").routingTo(req -> {
            try {
                Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData();
                List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles();
                if (files.isEmpty() || files.get(0).getRef() == null) {
                    return badRequest();
                }

                File file = files.get(0).getRef().path().toFile();
                SingleImageRecord record = new SingleImageRecord(file.toURI());

                return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
            } catch (Exception e) {
                log.error("Error in /transformincrementalimage", e);
                return internalServerError();
            }
        });

        routingDsl.POST("/transformarray").routingTo(req -> {
            try {
                BatchImageRecord batch = objectMapper.readValue(getJsonText(req), BatchImageRecord.class);
                if (batch == null)
                    return badRequest();
                return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
            } catch (Exception e) {
                log.error("Error in /transformarray", e);
                return internalServerError();
            }
        });

        routingDsl.POST("/transformimage").routingTo(req -> {
            try {
                Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData();
                List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles();
                if (files.size() == 0) {
                    return badRequest();
                }

                List<SingleImageRecord> records = new ArrayList<>();

                for (Http.MultipartFormData.FilePart<Files.TemporaryFile> filePart : files) {
                    Files.TemporaryFile file = filePart.getRef();
                    if (file != null) {
                        SingleImageRecord record = new SingleImageRecord(file.path().toUri());
                        records.add(record);
                    }
                }

                BatchImageRecord batch = new BatchImageRecord(records);

                return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
            } catch (Exception e) {
                log.error("Error in /transformimage", e);
                return internalServerError();
            }
        });

        return routingDsl.build();
    }

    @Override
    public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) {
        throw new UnsupportedOperationException();
    }

    @Override
    public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) {
        throw new UnsupportedOperationException();
    }

    @Override
    public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) {
        throw new UnsupportedOperationException();
    }

    @Override
    public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) {
        throw new UnsupportedOperationException();
    }

    @Override
    public void setCSVTransformProcess(TransformProcess transformProcess) {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    @Override
    public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) {
        this.transform = new ImageSparkTransform(imageTransformProcess);
    }

    @Override
    public TransformProcess getCSVTransformProcess() {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    @Override
    public ImageTransformProcess getImageTransformProcess() {
        return transform.getImageTransformProcess();
    }

    @Override
    public SingleCSVRecord transformIncremental(SingleCSVRecord singleCsvRecord) {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    @Override
    public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    @Override
    public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    @Override
    public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) {
|
|
||||||
throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Base64NDArrayBody transformIncrementalArray(SingleImageRecord record) throws IOException {
|
|
||||||
return transform.toArray(record);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Base64NDArrayBody transformArray(BatchImageRecord batch) throws IOException {
|
|
||||||
return transform.toArray(batch);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void main(String[] args) throws Exception {
|
|
||||||
new ImageSparkTransformServer().runMain(args);
|
|
||||||
}
|
|
||||||
}
|
|
|
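For reference, the routes above can be driven with the same Unirest client the tests later in this commit use. A minimal client sketch, assuming a server on port 9060 and a saved ImageTransformProcess JSON file named imageTransform.json (both placeholders, not values from the removed code):

// Sketch only: posts a raw ImageTransformProcess JSON to /transformprocess,
// then uploads one image to /transformimage as multipart form data.
// Port 9060, imageTransform.json and image.jpg are illustrative placeholders.
import com.mashape.unirest.http.Unirest;
import org.apache.commons.io.FileUtils;
import java.io.File;

public class ImageTransformClientSketch {
    public static void main(String[] args) throws Exception {
        // Register the transform process (the POST /transformprocess route above).
        String json = FileUtils.readFileToString(new File("imageTransform.json"));
        Unirest.post("http://localhost:9060/transformprocess")
                .header("accept", "application/json")
                .header("Content-Type", "application/json")
                .body(json).asJson();

        // Transform a single image (the POST /transformimage route above).
        Unirest.post("http://localhost:9060/transformimage")
                .header("accept", "application/json")
                .field("file1", new File("image.jpg"))
                .asJson();
    }
}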
@@ -1,67 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.server;

import com.beust.jcommander.Parameter;
import com.fasterxml.jackson.databind.JsonNode;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.service.DataVecTransformService;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import play.mvc.Http;
import play.server.Server;

public abstract class SparkTransformServer implements DataVecTransformService {
    @Parameter(names = {"-j", "--jsonPath"}, arity = 1)
    protected String jsonPath = null;
    @Parameter(names = {"-dp", "--dataVecPort"}, arity = 1)
    protected int port = 9000;
    @Parameter(names = {"-dt", "--dataType"}, arity = 1)
    private TransformDataType transformDataType = null;
    protected Server server;
    protected static ObjectMapper objectMapper = new ObjectMapper();
    protected static String contentType = "application/json";

    public abstract void runMain(String[] args) throws Exception;

    /**
     * Stop the server
     */
    public void stop() {
        if (server != null)
            server.stop();
    }

    protected boolean isSequence(Http.Request request) {
        return request.hasHeader(SEQUENCE_OR_NOT_HEADER)
                        && request.header(SEQUENCE_OR_NOT_HEADER).get().equalsIgnoreCase("true");
    }

    protected String getJsonText(Http.Request request) {
        JsonNode tryJson = request.body().asJson();
        if (tryJson != null)
            return tryJson.toString();
        else
            return request.body().asText();
    }

    public abstract Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord);
}
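The isSequence() helper above keys off the SEQUENCE_OR_NOT_HEADER constant inherited from DataVecTransformService. A hedged sketch of how a client would flag a request as a sequence; the port and body are placeholders, not values from the removed code:

// Sketch only: set the header checked by isSequence() to "true".
// Assumes SEQUENCE_OR_NOT_HEADER is a constant on the DataVecTransformService
// interface (it is referenced unqualified in the class above).
import com.mashape.unirest.http.Unirest;
import static org.datavec.spark.inference.model.service.DataVecTransformService.SEQUENCE_OR_NOT_HEADER;

public class SequenceHeaderSketch {
    public static void main(String[] args) throws Exception {
        Unirest.post("http://localhost:9050/transform")    // placeholder port
                .header(SEQUENCE_OR_NOT_HEADER, "true")    // opt into sequence handling
                .header("Content-Type", "application/json")
                .body("{}")                                // placeholder body
                .asJson();
    }
}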
@@ -1,76 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.server;

import lombok.Data;
import lombok.extern.slf4j.Slf4j;

import java.io.InvalidClassException;
import java.util.Arrays;
import java.util.List;

@Data
@Slf4j
public class SparkTransformServerChooser {
    private SparkTransformServer sparkTransformServer = null;
    private TransformDataType transformDataType = null;

    public void runMain(String[] args) throws Exception {

        int pos = getMatchingPosition(args, "-dt", "--dataType");
        if (pos == -1) {
            log.error("No valid options specified");
            log.error("-dt, --dataType Options: [CSV, IMAGE]");
            throw new Exception("no valid options");
        } else {
            transformDataType = TransformDataType.valueOf(args[pos + 1]);
        }

        switch (transformDataType) {
            case CSV:
                sparkTransformServer = new CSVSparkTransformServer();
                break;
            case IMAGE:
                sparkTransformServer = new ImageSparkTransformServer();
                break;
            default:
                throw new InvalidClassException("no matching SparkTransform class");
        }

        sparkTransformServer.runMain(args);
    }

    private int getMatchingPosition(String[] args, String... options) {
        List<String> optionList = Arrays.asList(options);

        for (int i = 0; i < args.length; i++) {
            if (optionList.contains(args[i])) {
                return i;
            }
        }
        return -1;
    }

    public static void main(String[] args) throws Exception {
        new SparkTransformServerChooser().runMain(args);
    }
}
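Putting the chooser together: the tests later in this commit start it exactly this way. A minimal launch sketch, where the JSON path is a placeholder:

// Sketch only: mirrors how the tests in this commit invoke runMain().
// "transform.json" is a placeholder path, not a value from the removed code.
public class ChooserLaunchSketch {
    public static void main(String[] args) throws Exception {
        new SparkTransformServerChooser().runMain(new String[] {
                "--jsonPath", "transform.json",    // saved TransformProcess JSON
                "-dp", "9050",                     // DataVec server port
                "-dt", "CSV"                       // CSV or IMAGE
        });
    }
}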
@@ -1,25 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.server;

public enum TransformDataType {
    CSV, IMAGE
}
@@ -1,350 +0,0 @@
# This is the main configuration file for the application.
# https://www.playframework.com/documentation/latest/ConfigFile
# ~~~~~
# Play uses HOCON as its configuration file format. HOCON has a number
# of advantages over other config formats, but there are two things that
# can be used when modifying settings.
#
# You can include other configuration files in this main application.conf file:
#include "extra-config.conf"
#
# You can declare variables and substitute for them:
#mykey = ${some.value}
#
# And if an environment variable exists when there is no other substitution, then
# HOCON will fall back to substituting the environment variable:
#mykey = ${JAVA_HOME}

## Akka
# https://www.playframework.com/documentation/latest/ScalaAkka#Configuration
# https://www.playframework.com/documentation/latest/JavaAkka#Configuration
# ~~~~~
# Play uses Akka internally and exposes Akka Streams and actors in Websockets and
# other streaming HTTP responses.
akka {
  # "akka.log-config-on-start" is extraordinarily useful because it logs the complete
  # configuration at INFO level, including defaults and overrides, so it's worth
  # putting at the very top.
  #
  # Put the following in your conf/logback.xml file:
  #
  # <logger name="akka.actor" level="INFO" />
  #
  # And then uncomment this line to debug the configuration.
  #
  #log-config-on-start = true
}

## Modules
# https://www.playframework.com/documentation/latest/Modules
# ~~~~~
# Control which modules are loaded when Play starts. Note that modules are
# the replacement for "GlobalSettings", which are deprecated in 2.5.x.
# Please see https://www.playframework.com/documentation/latest/GlobalSettings
# for more information.
#
# You can also extend Play functionality by using one of the publicly available
# Play modules: https://playframework.com/documentation/latest/ModuleDirectory
play.modules {
  # By default, Play will load any class called Module that is defined
  # in the root package (the "app" directory), or you can define them
  # explicitly below.
  # If there are any modules that you want to enable, you can list them here.
  #enabled += my.application.Module

  # If there are any built-in modules that you want to disable, you can list them here.
  #disabled += ""
}

## Internationalisation
# https://www.playframework.com/documentation/latest/JavaI18N
# https://www.playframework.com/documentation/latest/ScalaI18N
# ~~~~~
# Play comes with its own i18n settings, which allow the user's preferred language
# to map through to internal messages, or allow the language to be stored in a cookie.
play.i18n {
  # The application languages
  langs = [ "en" ]

  # Whether the language cookie should be secure or not
  #langCookieSecure = true

  # Whether the HTTP only attribute of the cookie should be set to true
  #langCookieHttpOnly = true
}

## Play HTTP settings
# ~~~~~
play.http {
  ## Router
  # https://www.playframework.com/documentation/latest/JavaRouting
  # https://www.playframework.com/documentation/latest/ScalaRouting
  # ~~~~~
  # Define the Router object to use for this application.
  # This router will be looked up first when the application is starting up,
  # so make sure this is the entry point.
  # Furthermore, it's assumed your route file is named properly.
  # So for an application router like `my.application.Router`,
  # you may need to define a router file `conf/my.application.routes`.
  # Default to Routes in the root package (aka "apps" folder) (and conf/routes)
  #router = my.application.Router

  ## Action Creator
  # https://www.playframework.com/documentation/latest/JavaActionCreator
  # ~~~~~
  #actionCreator = null

  ## ErrorHandler
  # https://www.playframework.com/documentation/latest/JavaRouting
  # https://www.playframework.com/documentation/latest/ScalaRouting
  # ~~~~~
  # If null, will attempt to load a class called ErrorHandler in the root package,
  #errorHandler = null

  ## Filters
  # https://www.playframework.com/documentation/latest/ScalaHttpFilters
  # https://www.playframework.com/documentation/latest/JavaHttpFilters
  # ~~~~~
  # Filters run code on every request. They can be used to perform
  # common logic for all your actions, e.g. adding common headers.
  # Defaults to "Filters" in the root package (aka "apps" folder)
  # Alternatively you can explicitly register a class here.
  #filters += my.application.Filters

  ## Session & Flash
  # https://www.playframework.com/documentation/latest/JavaSessionFlash
  # https://www.playframework.com/documentation/latest/ScalaSessionFlash
  # ~~~~~
  session {
    # Sets the cookie to be sent only over HTTPS.
    #secure = true

    # Sets the cookie to be accessed only by the server.
    #httpOnly = true

    # Sets the max-age field of the cookie to 5 minutes.
    # NOTE: this only sets when the browser will discard the cookie. Play will consider any
    # cookie value with a valid signature to be a valid session forever. To implement a server side session timeout,
    # you need to put a timestamp in the session and check it at regular intervals to possibly expire it.
    #maxAge = 300

    # Sets the domain on the session cookie.
    #domain = "example.com"
  }

  flash {
    # Sets the cookie to be sent only over HTTPS.
    #secure = true

    # Sets the cookie to be accessed only by the server.
    #httpOnly = true
  }
}

## Netty Provider
# https://www.playframework.com/documentation/latest/SettingsNetty
# ~~~~~
play.server.netty {
  # Whether the Netty wire should be logged
  #log.wire = true

  # If you run Play on Linux, you can use Netty's native socket transport
  # for higher performance with less garbage.
  #transport = "native"
}

## WS (HTTP Client)
# https://www.playframework.com/documentation/latest/ScalaWS#Configuring-WS
# ~~~~~
# The HTTP client primarily used for REST APIs. The default client can be
# configured directly, but you can also create different client instances
# with customized settings. You must enable this by adding to build.sbt:
#
# libraryDependencies += ws // or javaWs if using java
#
play.ws {
  # Sets HTTP requests not to follow 302 requests
  #followRedirects = false

  # Sets the maximum number of open HTTP connections for the client.
  #ahc.maxConnectionsTotal = 50

  ## WS SSL
  # https://www.playframework.com/documentation/latest/WsSSL
  # ~~~~~
  ssl {
    # Configuring HTTPS with Play WS does not require programming. You can
    # set up both trustManager and keyManager for mutual authentication, and
    # turn on JSSE debugging in development with a reload.
    #debug.handshake = true
    #trustManager = {
    #  stores = [
    #    { type = "JKS", path = "exampletrust.jks" }
    #  ]
    #}
  }
}

## Cache
# https://www.playframework.com/documentation/latest/JavaCache
# https://www.playframework.com/documentation/latest/ScalaCache
# ~~~~~
# Play comes with an integrated cache API that can reduce the operational
# overhead of repeated requests. You must enable this by adding to build.sbt:
#
# libraryDependencies += cache
#
play.cache {
  # If you want to bind several caches, you can bind them individually
  #bindCaches = ["db-cache", "user-cache", "session-cache"]
}

## Filters
# https://www.playframework.com/documentation/latest/Filters
# ~~~~~
# There are a number of built-in filters that can be enabled and configured
# to give Play greater security. You must enable this by adding to build.sbt:
#
# libraryDependencies += filters
#
play.filters {
  ## CORS filter configuration
  # https://www.playframework.com/documentation/latest/CorsFilter
  # ~~~~~
  # CORS is a protocol that allows web applications to make requests from the browser
  # across different domains.
  # NOTE: You MUST apply the CORS configuration before the CSRF filter, as CSRF has
  # dependencies on CORS settings.
  cors {
    # Filter paths by a whitelist of path prefixes
    #pathPrefixes = ["/some/path", ...]

    # The allowed origins. If null, all origins are allowed.
    #allowedOrigins = ["http://www.example.com"]

    # The allowed HTTP methods. If null, all methods are allowed
    #allowedHttpMethods = ["GET", "POST"]
  }

  ## CSRF Filter
  # https://www.playframework.com/documentation/latest/ScalaCsrf#Applying-a-global-CSRF-filter
  # https://www.playframework.com/documentation/latest/JavaCsrf#Applying-a-global-CSRF-filter
  # ~~~~~
  # Play supports multiple methods for verifying that a request is not a CSRF request.
  # The primary mechanism is a CSRF token. This token gets placed either in the query string
  # or body of every form submitted, and also gets placed in the user's session.
  # Play then verifies that both tokens are present and match.
  csrf {
    # Sets the cookie to be sent only over HTTPS
    #cookie.secure = true

    # Defaults to CSRFErrorHandler in the root package.
    #errorHandler = MyCSRFErrorHandler
  }

  ## Security headers filter configuration
  # https://www.playframework.com/documentation/latest/SecurityHeaders
  # ~~~~~
  # Defines security headers that prevent XSS attacks.
  # If enabled, then all options are set to the below configuration by default:
  headers {
    # The X-Frame-Options header. If null, the header is not set.
    #frameOptions = "DENY"

    # The X-XSS-Protection header. If null, the header is not set.
    #xssProtection = "1; mode=block"

    # The X-Content-Type-Options header. If null, the header is not set.
    #contentTypeOptions = "nosniff"

    # The X-Permitted-Cross-Domain-Policies header. If null, the header is not set.
    #permittedCrossDomainPolicies = "master-only"

    # The Content-Security-Policy header. If null, the header is not set.
    #contentSecurityPolicy = "default-src 'self'"
  }

  ## Allowed hosts filter configuration
  # https://www.playframework.com/documentation/latest/AllowedHostsFilter
  # ~~~~~
  # Play provides a filter that lets you configure which hosts can access your application.
  # This is useful to prevent cache poisoning attacks.
  hosts {
    # Allow requests to example.com, its subdomains, and localhost:9000.
    #allowed = [".example.com", "localhost:9000"]
  }
}

## Evolutions
# https://www.playframework.com/documentation/latest/Evolutions
# ~~~~~
# Evolutions allows database scripts to be automatically run on startup in dev mode
# for database migrations. You must enable this by adding to build.sbt:
#
# libraryDependencies += evolutions
#
play.evolutions {
  # You can disable evolutions for a specific datasource if necessary
  #db.default.enabled = false
}

## Database Connection Pool
# https://www.playframework.com/documentation/latest/SettingsJDBC
# ~~~~~
# Play doesn't require a JDBC database to run, but you can easily enable one.
#
# libraryDependencies += jdbc
#
play.db {
  # The combination of these two settings results in "db.default" as the
  # default JDBC pool:
  #config = "db"
  #default = "default"

  # Play uses HikariCP as the default connection pool. You can override
  # settings by changing the prototype:
  prototype {
    # Sets a fixed JDBC connection pool size of 50
    #hikaricp.minimumIdle = 50
    #hikaricp.maximumPoolSize = 50
  }
}

## JDBC Datasource
# https://www.playframework.com/documentation/latest/JavaDatabase
# https://www.playframework.com/documentation/latest/ScalaDatabase
# ~~~~~
# Once JDBC datasource is set up, you can work with several different
# database options:
#
# Slick (Scala preferred option): https://www.playframework.com/documentation/latest/PlaySlick
# JPA (Java preferred option): https://playframework.com/documentation/latest/JavaJPA
# EBean: https://playframework.com/documentation/latest/JavaEbean
# Anorm: https://www.playframework.com/documentation/latest/ScalaAnorm
#
db {
  # You can declare as many datasources as you want.
  # By convention, the default datasource is named `default`

  # https://www.playframework.com/documentation/latest/Developing-with-the-H2-Database
  default.driver = org.h2.Driver
  default.url = "jdbc:h2:mem:play"
  #default.username = sa
  #default.password = ""

  # You can expose this datasource via JNDI if needed (Useful for JPA)
  default.jndiName=DefaultDS

  # You can turn on SQL logging for any datasource
  # https://www.playframework.com/documentation/latest/Highlights25#Logging-SQL-statements
  #default.logSql=true
}

jpa.default=defaultPersistenceUnit

# Increase default maximum post length - used for remote listener functionality
# Can get response 413 with larger networks without setting this
# parsers.text.maxLength is deprecated, use play.http.parser.maxMemoryBuffer instead
#parsers.text.maxLength=10M
play.http.parser.maxMemoryBuffer=10M
@@ -1,46 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.datavec.spark.transform;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;

import java.util.*;

@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {

    @Override
    protected Set<Class<?>> getExclusions() {
        //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
        return new HashSet<>();
    }

    @Override
    protected String getPackageName() {
        return "org.datavec.spark.transform";
    }

    @Override
    protected Class<?> getBaseClass() {
        return BaseND4JTest.class;
    }
}
@@ -1,127 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.transform;

import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.spark.inference.server.CSVSparkTransformServer;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

import java.io.File;
import java.io.IOException;
import java.util.UUID;

import static org.junit.Assert.assertTrue;
import static org.junit.Assume.assumeNotNull;

public class CSVSparkTransformServerNoJsonTest {

    private static CSVSparkTransformServer server;
    private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
    private static TransformProcess transformProcess =
                    new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build();
    private static File fileSave = new File(UUID.randomUUID().toString() + ".json");

    @BeforeClass
    public static void before() throws Exception {
        server = new CSVSparkTransformServer();
        FileUtils.write(fileSave, transformProcess.toJson());

        // Only one time
        Unirest.setObjectMapper(new ObjectMapper() {
            private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
                            new org.nd4j.shade.jackson.databind.ObjectMapper();

            public <T> T readValue(String value, Class<T> valueType) {
                try {
                    return jacksonObjectMapper.readValue(value, valueType);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }

            public String writeValue(Object value) {
                try {
                    return jacksonObjectMapper.writeValueAsString(value);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        });

        server.runMain(new String[] {"-dp", "9050"});
    }

    @AfterClass
    public static void after() throws Exception {
        fileSave.delete();
        server.stop();
    }

    @Test
    public void testServer() throws Exception {
        assertTrue(server.getTransform() == null);
        JsonNode jsonStatus = Unirest.post("http://localhost:9050/transformprocess")
                        .header("accept", "application/json").header("Content-Type", "application/json")
                        .body(transformProcess.toJson()).asJson().getBody();
        assumeNotNull(server.getTransform());

        String[] values = new String[] {"1.0", "2.0"};
        SingleCSVRecord record = new SingleCSVRecord(values);
        JsonNode jsonNode =
                        Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json")
                                        .header("Content-Type", "application/json").body(record).asJson().getBody();
        SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental")
                        .header("accept", "application/json").header("Content-Type", "application/json").body(record)
                        .asObject(SingleCSVRecord.class).getBody();

        BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
        for (int i = 0; i < 3; i++)
            batchCSVRecord.add(singleCsvRecord);
        /* BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform")
                        .header("accept", "application/json").header("Content-Type", "application/json")
                        .body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody();

        Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray")
                        .header("accept", "application/json").header("Content-Type", "application/json").body(record)
                        .asObject(Base64NDArrayBody.class).getBody();
        */
        Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray")
                        .header("accept", "application/json").header("Content-Type", "application/json")
                        .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody();
    }

}
@@ -1,121 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.transform;

import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.spark.inference.server.CSVSparkTransformServer;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

import java.io.File;
import java.io.IOException;
import java.util.UUID;

public class CSVSparkTransformServerTest {

    private static CSVSparkTransformServer server;
    private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
    private static TransformProcess transformProcess =
                    new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build();
    private static File fileSave = new File(UUID.randomUUID().toString() + ".json");

    @BeforeClass
    public static void before() throws Exception {
        server = new CSVSparkTransformServer();
        FileUtils.write(fileSave, transformProcess.toJson());
        // Only one time

        Unirest.setObjectMapper(new ObjectMapper() {
            private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
                            new org.nd4j.shade.jackson.databind.ObjectMapper();

            public <T> T readValue(String value, Class<T> valueType) {
                try {
                    return jacksonObjectMapper.readValue(value, valueType);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }

            public String writeValue(Object value) {
                try {
                    return jacksonObjectMapper.writeValueAsString(value);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        });

        server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9050"});
    }

    @AfterClass
    public static void after() throws Exception {
        fileSave.deleteOnExit();
        server.stop();
    }

    @Test
    public void testServer() throws Exception {
        String[] values = new String[] {"1.0", "2.0"};
        SingleCSVRecord record = new SingleCSVRecord(values);
        JsonNode jsonNode =
                        Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json")
                                        .header("Content-Type", "application/json").body(record).asJson().getBody();
        SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental")
                        .header("accept", "application/json").header("Content-Type", "application/json").body(record)
                        .asObject(SingleCSVRecord.class).getBody();

        BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
        for (int i = 0; i < 3; i++)
            batchCSVRecord.add(singleCsvRecord);
        BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform")
                        .header("accept", "application/json").header("Content-Type", "application/json")
                        .body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody();

        Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray")
                        .header("accept", "application/json").header("Content-Type", "application/json").body(record)
                        .asObject(Base64NDArrayBody.class).getBody();

        Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray")
                        .header("accept", "application/json").header("Content-Type", "application/json")
                        .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody();
    }

}
@@ -1,164 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.transform;

import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import org.apache.commons.io.FileUtils;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.server.ImageSparkTransformServer;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchImageRecord;
import org.datavec.spark.inference.model.model.SingleImageRecord;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.serde.base64.Nd4jBase64;

import java.io.File;
import java.io.IOException;
import java.util.UUID;

import static org.junit.Assert.assertEquals;

public class ImageSparkTransformServerTest {

    @Rule
    public TemporaryFolder testDir = new TemporaryFolder();

    private static ImageSparkTransformServer server;
    private static File fileSave = new File(UUID.randomUUID().toString() + ".json");

    @BeforeClass
    public static void before() throws Exception {
        server = new ImageSparkTransformServer();

        ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(12345)
                        .scaleImageTransform(10).cropImageTransform(5).build();

        FileUtils.write(fileSave, imgTransformProcess.toJson());

        Unirest.setObjectMapper(new ObjectMapper() {
            private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
                            new org.nd4j.shade.jackson.databind.ObjectMapper();

            public <T> T readValue(String value, Class<T> valueType) {
                try {
                    return jacksonObjectMapper.readValue(value, valueType);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }

            public String writeValue(Object value) {
                try {
                    return jacksonObjectMapper.writeValueAsString(value);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        });

        server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9060"});
    }

    @AfterClass
    public static void after() throws Exception {
        fileSave.deleteOnExit();
        server.stop();
    }

    @Test
    public void testImageServer() throws Exception {
        SingleImageRecord record =
                        new SingleImageRecord(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI());
        JsonNode jsonNode = Unirest.post("http://localhost:9060/transformincrementalarray")
                        .header("accept", "application/json").header("Content-Type", "application/json").body(record)
                        .asJson().getBody();
        Base64NDArrayBody array = Unirest.post("http://localhost:9060/transformincrementalarray")
                        .header("accept", "application/json").header("Content-Type", "application/json").body(record)
                        .asObject(Base64NDArrayBody.class).getBody();

        BatchImageRecord batch = new BatchImageRecord();
        batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI());
        batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile().toURI());
        batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile().toURI());

        JsonNode jsonNodeBatch =
                        Unirest.post("http://localhost:9060/transformarray").header("accept", "application/json")
                                        .header("Content-Type", "application/json").body(batch).asJson().getBody();
        Base64NDArrayBody batchArray = Unirest.post("http://localhost:9060/transformarray")
                        .header("accept", "application/json").header("Content-Type", "application/json").body(batch)
                        .asObject(Base64NDArrayBody.class).getBody();

        INDArray result = getNDArray(jsonNode);
        assertEquals(1, result.size(0));

        INDArray batchResult = getNDArray(jsonNodeBatch);
        assertEquals(3, batchResult.size(0));

        // System.out.println(array);
    }

    @Test
    public void testImageServerMultipart() throws Exception {
        JsonNode jsonNode = Unirest.post("http://localhost:9060/transformimage")
                        .header("accept", "application/json")
                        .field("file1", new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile())
                        .field("file2", new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile())
                        .field("file3", new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile())
                        .asJson().getBody();

        INDArray batchResult = getNDArray(jsonNode);
        assertEquals(3, batchResult.size(0));

        // System.out.println(batchResult);
    }

    @Test
    public void testImageServerSingleMultipart() throws Exception {
        File f = testDir.newFolder();
        File imgFile = new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getTempFileFromArchive(f);

        JsonNode jsonNode = Unirest.post("http://localhost:9060/transformimage")
                        .header("accept", "application/json")
                        .field("file1", imgFile)
                        .asJson().getBody();

        INDArray result = getNDArray(jsonNode);
        assertEquals(1, result.size(0));

        // System.out.println(result);
    }

    public INDArray getNDArray(JsonNode node) throws IOException {
        return Nd4jBase64.fromBase64(node.getObject().getString("ndarray"));
    }
}
@@ -1,168 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.transform;

import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.server.SparkTransformServerChooser;
import org.datavec.spark.inference.server.TransformDataType;
import org.datavec.spark.inference.model.model.*;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.serde.base64.Nd4jBase64;

import java.io.File;
import java.io.IOException;
import java.util.UUID;

import static org.junit.Assert.assertEquals;

public class SparkTransformServerTest {
    private static SparkTransformServerChooser serverChooser;
    private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
    private static TransformProcess transformProcess =
                    new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build();

    private static File imageTransformFile = new File(UUID.randomUUID().toString() + ".json");
    private static File csvTransformFile = new File(UUID.randomUUID().toString() + ".json");

    @BeforeClass
    public static void before() throws Exception {
        serverChooser = new SparkTransformServerChooser();

        ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(12345)
                        .scaleImageTransform(10).cropImageTransform(5).build();

        FileUtils.write(imageTransformFile, imgTransformProcess.toJson());

        FileUtils.write(csvTransformFile, transformProcess.toJson());

        Unirest.setObjectMapper(new ObjectMapper() {
            private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
                            new org.nd4j.shade.jackson.databind.ObjectMapper();

            public <T> T readValue(String value, Class<T> valueType) {
                try {
                    return jacksonObjectMapper.readValue(value, valueType);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }

            public String writeValue(Object value) {
                try {
                    return jacksonObjectMapper.writeValueAsString(value);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        });

    }

    @AfterClass
    public static void after() throws Exception {
        imageTransformFile.deleteOnExit();
        csvTransformFile.deleteOnExit();
    }

    @Test
    public void testImageServer() throws Exception {
        serverChooser.runMain(new String[] {"--jsonPath", imageTransformFile.getAbsolutePath(), "-dp", "9060", "-dt",
                        TransformDataType.IMAGE.toString()});

        SingleImageRecord record =
                        new SingleImageRecord(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI());
        JsonNode jsonNode = Unirest.post("http://localhost:9060/transformincrementalarray")
                        .header("accept", "application/json").header("Content-Type", "application/json").body(record)
                        .asJson().getBody();
        Base64NDArrayBody array = Unirest.post("http://localhost:9060/transformincrementalarray")
                        .header("accept", "application/json").header("Content-Type", "application/json").body(record)
                        .asObject(Base64NDArrayBody.class).getBody();

        BatchImageRecord batch = new BatchImageRecord();
        batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI());
        batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile().toURI());
        batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile().toURI());

        JsonNode jsonNodeBatch =
                        Unirest.post("http://localhost:9060/transformarray").header("accept", "application/json")
                                        .header("Content-Type", "application/json").body(batch).asJson().getBody();
        Base64NDArrayBody batchArray = Unirest.post("http://localhost:9060/transformarray")
                        .header("accept", "application/json").header("Content-Type", "application/json").body(batch)
                        .asObject(Base64NDArrayBody.class).getBody();

        INDArray result = getNDArray(jsonNode);
        assertEquals(1, result.size(0));

        INDArray batchResult = getNDArray(jsonNodeBatch);
        assertEquals(3, batchResult.size(0));

        serverChooser.getSparkTransformServer().stop();
    }

    @Test
    public void testCSVServer() throws Exception {
        serverChooser.runMain(new String[] {"--jsonPath", csvTransformFile.getAbsolutePath(), "-dp", "9050", "-dt",
                        TransformDataType.CSV.toString()});

        String[] values = new String[] {"1.0", "2.0"};
        SingleCSVRecord record = new SingleCSVRecord(values);
        JsonNode jsonNode =
                        Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json")
                                        .header("Content-Type", "application/json").body(record).asJson().getBody();
        SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental")
                        .header("accept", "application/json").header("Content-Type", "application/json").body(record)
                        .asObject(SingleCSVRecord.class).getBody();

        BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
        for (int i = 0; i < 3; i++)
            batchCSVRecord.add(singleCsvRecord);
        BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform")
                        .header("accept", "application/json").header("Content-Type", "application/json")
                        .body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody();

        Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray")
                        .header("accept", "application/json").header("Content-Type", "application/json").body(record)
                        .asObject(Base64NDArrayBody.class).getBody();

        Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray")
                        .header("accept", "application/json").header("Content-Type", "application/json")
                        .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody();

        serverChooser.getSparkTransformServer().stop();
    }

    public INDArray getNDArray(JsonNode node) throws IOException {
        return Nd4jBase64.fromBase64(node.getObject().getString("ndarray"));
    }
}
@ -1,6 +0,0 @@
play.modules.enabled += com.lightbend.lagom.discovery.zookeeper.ZooKeeperServiceLocatorModule
play.modules.enabled += io.skymind.skil.service.PredictionModule
play.crypto.secret = as8dufasdfuasdfjkasdkfalksjfk
play.server.pidfile.path=/tmp/RUNNING_PID

play.server.http.port = 9600
@ -1,68 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  ~ /* ******************************************************************************
  ~ *
  ~ *
  ~ * This program and the accompanying materials are made available under the
  ~ * terms of the Apache License, Version 2.0 which is available at
  ~ * https://www.apache.org/licenses/LICENSE-2.0.
  ~ *
  ~ * See the NOTICE file distributed with this work for additional
  ~ * information regarding copyright ownership.
  ~ * Unless required by applicable law or agreed to in writing, software
  ~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~ * License for the specific language governing permissions and limitations
  ~ * under the License.
  ~ *
  ~ * SPDX-License-Identifier: Apache-2.0
  ~ ******************************************************************************/
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>datavec-spark-inference-parent</artifactId>
    <packaging>pom</packaging>

    <name>datavec-spark-inference-parent</name>

    <modules>
        <module>datavec-spark-inference-server</module>
        <module>datavec-spark-inference-client</module>
        <module>datavec-spark-inference-model</module>
    </modules>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.datavec</groupId>
                <artifactId>datavec-data-image</artifactId>
                <version>${datavec.version}</version>
            </dependency>
            <dependency>
                <groupId>com.mashape.unirest</groupId>
                <artifactId>unirest-java</artifactId>
                <version>${unirest.version}</version>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
        </profile>
    </profiles>
</project>
@ -45,7 +45,6 @@
        <module>datavec-data</module>
        <module>datavec-spark</module>
        <module>datavec-local</module>
        <module>datavec-spark-inference-parent</module>
        <module>datavec-jdbc</module>
        <module>datavec-excel</module>
        <module>datavec-arrow</module>
@ -1,143 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  ~ /* ******************************************************************************
  ~ *
  ~ *
  ~ * This program and the accompanying materials are made available under the
  ~ * terms of the Apache License, Version 2.0 which is available at
  ~ * https://www.apache.org/licenses/LICENSE-2.0.
  ~ *
  ~ * See the NOTICE file distributed with this work for additional
  ~ * information regarding copyright ownership.
  ~ * Unless required by applicable law or agreed to in writing, software
  ~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~ * License for the specific language governing permissions and limitations
  ~ * under the License.
  ~ *
  ~ * SPDX-License-Identifier: Apache-2.0
  ~ ******************************************************************************/
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-nearestneighbors-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>deeplearning4j-nearestneighbor-server</artifactId>
    <packaging>jar</packaging>

    <name>deeplearning4j-nearestneighbor-server</name>

    <properties>
        <java.compile.version>1.8</java.compile.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-nearestneighbors-model</artifactId>
            <version>${project.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${project.version}</version>
        </dependency>
        <dependency>
            <groupId>io.vertx</groupId>
            <artifactId>vertx-core</artifactId>
            <version>${vertx.version}</version>
        </dependency>
        <dependency>
            <groupId>io.vertx</groupId>
            <artifactId>vertx-web</artifactId>
            <version>${vertx.version}</version>
        </dependency>
        <dependency>
            <groupId>com.mashape.unirest</groupId>
            <artifactId>unirest-java</artifactId>
            <version>${unirest.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-nearestneighbors-client</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>com.beust</groupId>
            <artifactId>jcommander</artifactId>
            <version>${jcommander.version}</version>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-common-tests</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-surefire-plugin</artifactId>
                <configuration>
                    <argLine>-Dfile.encoding=UTF-8 -Xmx8g</argLine>
                    <includes>
                        <!-- Default setting only runs tests that start/end with "Test" -->
                        <include>*.java</include>
                        <include>**/*.java</include>
                    </includes>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>${java.compile.version}</source>
                    <target>${java.compile.version}</target>
                </configuration>
            </plugin>
        </plugins>
    </build>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
            <dependencies>
                <dependency>
                    <groupId>org.nd4j</groupId>
                    <artifactId>nd4j-native</artifactId>
                    <version>${project.version}</version>
                    <scope>test</scope>
                </dependency>
            </dependencies>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
            <dependencies>
                <dependency>
                    <groupId>org.nd4j</groupId>
                    <artifactId>nd4j-cuda-11.0</artifactId>
                    <version>${project.version}</version>
                    <scope>test</scope>
                </dependency>
            </dependencies>
        </profile>
    </profiles>
</project>
@ -1,67 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nearestneighbor.server;

import lombok.AllArgsConstructor;
import lombok.Builder;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborRequest;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.ArrayList;
import java.util.List;

@AllArgsConstructor
@Builder
public class NearestNeighbor {
    private NearestNeighborRequest record;
    private VPTree tree;
    private INDArray points;

    public List<NearestNeighborsResult> search() {
        INDArray input = points.slice(record.getInputIndex());
        List<NearestNeighborsResult> results = new ArrayList<>();
        if (input.isVector()) {
            List<DataPoint> add = new ArrayList<>();
            List<Double> distances = new ArrayList<>();
            tree.search(input, record.getK(), add, distances);

            if (add.size() != distances.size()) {
                throw new IllegalStateException(
                        String.format("add.size == %d != %d == distances.size",
                                add.size(), distances.size()));
            }

            for (int i = 0; i < add.size(); i++) {
                results.add(new NearestNeighborsResult(add.get(i).getIndex(), distances.get(i)));
            }
        }

        return results;
    }
}
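For reference, a minimal usage sketch of the class above; the points matrix and k are illustrative, and the pattern mirrors the tests further down in this diff:

// Look up the two rows closest to row 0 of a small points matrix.
INDArray points = Nd4j.create(new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}});
NearestNeighborRequest request = new NearestNeighborRequest();
request.setK(2);
request.setInputIndex(0);
NearestNeighbor nn = NearestNeighbor.builder()
        .points(points)
        .tree(new VPTree(points, false))
        .record(request)
        .build();
List<NearestNeighborsResult> neighbors = nn.search();  // index + distance per neighbor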
@ -1,278 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nearestneighbor.server;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.vertx.core.AbstractVerticle;
import io.vertx.core.Vertx;
import io.vertx.ext.web.Router;
import io.vertx.ext.web.handler.BodyHandler;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.clustering.vptree.VPTreeFillSearch;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nearestneighbor.model.*;
import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.serde.base64.Nd4jBase64;
import org.nd4j.serde.binary.BinarySerde;

import java.io.File;
import java.util.*;

@Slf4j
public class NearestNeighborsServer extends AbstractVerticle {

    private static class RunArgs {
        @Parameter(names = {"--ndarrayPath"}, arity = 1, required = true)
        private String ndarrayPath = null;
        @Parameter(names = {"--labelsPath"}, arity = 1, required = false)
        private String labelsPath = null;
        @Parameter(names = {"--nearestNeighborsPort"}, arity = 1)
        private int port = 9000;
        @Parameter(names = {"--similarityFunction"}, arity = 1)
        private String similarityFunction = "euclidean";
        @Parameter(names = {"--invert"}, arity = 1)
        private boolean invert = false;
    }

    private static RunArgs instanceArgs;
    private static NearestNeighborsServer instance;

    public NearestNeighborsServer() { }

    public static NearestNeighborsServer getInstance() {
        return instance;
    }

    public static void runMain(String... args) {
        RunArgs r = new RunArgs();
        JCommander jcmdr = new JCommander(r);

        try {
            jcmdr.parse(args);
        } catch (ParameterException e) {
            log.error("Error in NearestNeighboursServer parameters", e);
            //User provided invalid input -> print the usage info
            StringBuilder sb = new StringBuilder();
            jcmdr.usage(sb);
            log.error("Usage: {}", sb.toString());
            if (r.ndarrayPath == null)
                log.error("--ndarrayPath parameter is missing (null)");
            try {
                Thread.sleep(500);
            } catch (Exception e2) {
            }
            System.exit(1);
        }

        instanceArgs = r;
        try {
            Vertx vertx = Vertx.vertx();
            vertx.deployVerticle(NearestNeighborsServer.class.getName());
        } catch (Throwable t) {
            log.error("Error in NearestNeighboursServer run method", t);
        }
    }

    @Override
    public void start() throws Exception {
        instance = this;

        String[] pathArr = instanceArgs.ndarrayPath.split(",");
        // First, read the shapes of the chunk files saved earlier so the full matrix can be allocated up front
        int rows = 0;
        int cols = 0;
        for (int i = 0; i < pathArr.length; i++) {
            DataBuffer shape = BinarySerde.readShapeFromDisk(new File(pathArr[i]));

            log.info("Loading shape {} of {}; Shape: [{} x {}]", i + 1, pathArr.length, Shape.size(shape, 0),
                            Shape.size(shape, 1));

            if (Shape.rank(shape) != 2)
                throw new DL4JInvalidInputException("NearestNeighborsServer assumes 2D chunks");

            rows += Shape.size(shape, 0);

            if (cols == 0)
                cols = Shape.size(shape, 1);
            else if (cols != Shape.size(shape, 1))
                throw new DL4JInvalidInputException(
                                "NearestNeighborsServer requires equal 2D chunks. Got columns mismatch.");
        }

        final List<String> labels = new ArrayList<>();
        if (instanceArgs.labelsPath != null) {
            String[] labelsPathArr = instanceArgs.labelsPath.split(",");
            for (int i = 0; i < labelsPathArr.length; i++) {
                labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8"));
            }
        }
        if (!labels.isEmpty() && labels.size() != rows)
            throw new DL4JInvalidInputException(String.format("Number of labels must match number of rows in points matrix (expected %d, found %d)", rows, labels.size()));

        final INDArray points = Nd4j.createUninitialized(rows, cols);

        int lastPosition = 0;
        for (int i = 0; i < pathArr.length; i++) {
            log.info("Loading chunk {} of {}", i + 1, pathArr.length);
            INDArray pointsArr = BinarySerde.readFromDisk(new File(pathArr[i]));

            points.get(NDArrayIndex.interval(lastPosition, lastPosition + pointsArr.rows())).assign(pointsArr);
            lastPosition += pointsArr.rows();

            // Encourage GC so we don't hold more than one chunk in memory going into the next iteration
            System.gc();
        }

        VPTree tree = new VPTree(points, instanceArgs.similarityFunction, instanceArgs.invert);

        //Set play secret key, if required
        //http://www.playframework.com/documentation/latest/ApplicationSecret
        String crypto = System.getProperty("play.crypto.secret");
        if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) {
            byte[] newCrypto = new byte[1024];

            new Random().nextBytes(newCrypto);

            String base64 = Base64.getEncoder().encodeToString(newCrypto);
            System.setProperty("play.crypto.secret", base64);
        }

        Router r = Router.router(vertx);
        r.route().handler(BodyHandler.create()); //NOTE: Setting this is required to receive request body content at all
        createRoutes(r, labels, tree, points);

        vertx.createHttpServer()
                .requestHandler(r)
                .listen(instanceArgs.port);
    }

    private void createRoutes(Router r, List<String> labels, VPTree tree, INDArray points) {

        r.post("/knn").handler(rc -> {
            try {
                String json = rc.getBodyAsJson().encode();
                NearestNeighborRequest record = JsonMappers.getMapper().readValue(json, NearestNeighborRequest.class);

                if (record == null) {
                    rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
                            .putHeader("content-type", "application/json")
                            .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
                    return;
                }

                NearestNeighbor nearestNeighbor =
                                NearestNeighbor.builder().points(points).record(record).tree(tree).build();

                NearestNeighborsResults results = NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();

                rc.response().setStatusCode(HttpResponseStatus.OK.code())
                        .putHeader("content-type", "application/json")
                        .end(JsonMappers.getMapper().writeValueAsString(results));
            } catch (Throwable e) {
                log.error("Error in POST /knn", e);
                rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
                        .end("Error parsing request - " + e.getMessage());
            }
        });

        r.post("/knnnew").handler(rc -> {
            try {
                String json = rc.getBodyAsJson().encode();
                Base64NDArrayBody record = JsonMappers.getMapper().readValue(json, Base64NDArrayBody.class);
                if (record == null) {
                    rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
                            .putHeader("content-type", "application/json")
                            .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
                    return;
                }

                INDArray arr = Nd4jBase64.fromBase64(record.getNdarray());
                List<DataPoint> results;
                List<Double> distances;

                if (record.isForceFillK()) {
                    VPTreeFillSearch vpTreeFillSearch = new VPTreeFillSearch(tree, record.getK(), arr);
                    vpTreeFillSearch.search();
                    results = vpTreeFillSearch.getResults();
                    distances = vpTreeFillSearch.getDistances();
                } else {
                    results = new ArrayList<>();
                    distances = new ArrayList<>();
                    tree.search(arr, record.getK(), results, distances);
                }

                if (results.size() != distances.size()) {
                    rc.response()
                            .setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
                            .end(String.format("results.size == %d != %d == distances.size", results.size(), distances.size()));
                    return;
                }

                List<NearestNeighborsResult> nnResult = new ArrayList<>();
                for (int i = 0; i < results.size(); i++) {
                    if (!labels.isEmpty())
                        nnResult.add(new NearestNeighborsResult(results.get(i).getIndex(), distances.get(i), labels.get(results.get(i).getIndex())));
                    else
                        nnResult.add(new NearestNeighborsResult(results.get(i).getIndex(), distances.get(i)));
                }

                NearestNeighborsResults results2 = NearestNeighborsResults.builder().results(nnResult).build();
                String j = JsonMappers.getMapper().writeValueAsString(results2);
                rc.response()
                        .putHeader("content-type", "application/json")
                        .end(j);
            } catch (Throwable e) {
                log.error("Error in POST /knnnew", e);
                rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
                        .end("Error parsing request - " + e.getMessage());
            }
        });
    }

    /**
     * Stop the server
     */
    public void stop() throws Exception {
        super.stop();
    }

    public static void main(String[] args) throws Exception {
        runMain(args);
    }
}
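For reference, a minimal launch sketch; the chunk file paths are illustrative, and the flags are exactly the RunArgs parameters defined above:

// Chunk files would have been written earlier with BinarySerde.writeArrayToDisk(arr, file),
// as the test further down in this diff does; the paths and port here are hypothetical.
NearestNeighborsServer.runMain(
        "--ndarrayPath", "/data/points-0.bin,/data/points-1.bin",
        "--labelsPath", "/data/labels.txt",
        "--nearestNeighborsPort", "9000",
        "--similarityFunction", "euclidean");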
@ -1,161 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nearestneighbor.server;

import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.clustering.vptree.VPTreeFillSearch;
import org.deeplearning4j.nearestneighbor.client.NearestNeighborsClient;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborRequest;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResults;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.serde.binary.BinarySerde;

import java.io.File;
import java.io.IOException;
import java.net.ServerSocket;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;

import static org.junit.Assert.assertEquals;

public class NearestNeighborTest extends BaseDL4JTest {

    @Rule
    public TemporaryFolder testDir = new TemporaryFolder();

    @Test
    public void testNearestNeighbor() {
        double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}};
        INDArray arr = Nd4j.create(data);

        VPTree vpTree = new VPTree(arr, false);
        NearestNeighborRequest request = new NearestNeighborRequest();
        request.setK(2);
        request.setInputIndex(0);
        NearestNeighbor nearestNeighbor = NearestNeighbor.builder().tree(vpTree).points(arr).record(request).build();
        List<NearestNeighborsResult> results = nearestNeighbor.search();
        assertEquals(1, results.get(0).getIndex());
        assertEquals(2, results.size());

        assertEquals(1.0, results.get(0).getDistance(), 1e-4);
        assertEquals(4.0, results.get(1).getDistance(), 1e-4);
    }

    @Test
    public void testNearestNeighborInverted() {
        double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}};
        INDArray arr = Nd4j.create(data);

        VPTree vpTree = new VPTree(arr, true);
        NearestNeighborRequest request = new NearestNeighborRequest();
        request.setK(2);
        request.setInputIndex(0);
        NearestNeighbor nearestNeighbor = NearestNeighbor.builder().tree(vpTree).points(arr).record(request).build();
        List<NearestNeighborsResult> results = nearestNeighbor.search();
        assertEquals(2, results.get(0).getIndex());
        assertEquals(2, results.size());

        assertEquals(-4.0, results.get(0).getDistance(), 1e-4);
        assertEquals(-1.0, results.get(1).getDistance(), 1e-4);
    }

    @Test
    public void vpTreeTest() throws Exception {
        INDArray matrix = Nd4j.rand(new int[] {400, 10});
        INDArray rowVector = matrix.getRow(70);
        INDArray resultArr = Nd4j.zeros(400, 1);
        Executor executor = Executors.newSingleThreadExecutor();
        VPTree vpTree = new VPTree(matrix);
        System.out.println("Ran!");
    }

    public static int getAvailablePort() {
        try {
            ServerSocket socket = new ServerSocket(0);
            try {
                return socket.getLocalPort();
            } finally {
                socket.close();
            }
        } catch (IOException e) {
            throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e);
        }
    }

    @Test
    public void testServer() throws Exception {
        int localPort = getAvailablePort();
        Nd4j.getRandom().setSeed(7);
        INDArray rand = Nd4j.randn(10, 5);
        File writeToTmp = testDir.newFile();
        writeToTmp.deleteOnExit();
        BinarySerde.writeArrayToDisk(rand, writeToTmp);
        NearestNeighborsServer.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort",
                        String.valueOf(localPort));

        Thread.sleep(3000);

        NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort);
        NearestNeighborsResults result = client.knnNew(5, rand.getRow(0));
        assertEquals(5, result.getResults().size());
        NearestNeighborsServer.getInstance().stop();
    }

    @Test
    public void testFullSearch() throws Exception {
        int numRows = 1000;
        int numCols = 100;
        int numNeighbors = 42;
        INDArray points = Nd4j.rand(numRows, numCols);
        VPTree tree = new VPTree(points);
        INDArray query = Nd4j.rand(new int[] {1, numCols});
        VPTreeFillSearch fillSearch = new VPTreeFillSearch(tree, numNeighbors, query);
        fillSearch.search();
        List<DataPoint> results = fillSearch.getResults();
        List<Double> distances = fillSearch.getDistances();
        assertEquals(numNeighbors, distances.size());
        assertEquals(numNeighbors, results.size());
    }

    @Test
    public void testDistances() {
        INDArray indArray = Nd4j.create(new float[][]{{3, 4}, {1, 2}, {5, 6}});
        INDArray record = Nd4j.create(new float[][]{{7, 6}});
        VPTree vpTree = new VPTree(indArray, "euclidean", false);
        VPTreeFillSearch vpTreeFillSearch = new VPTreeFillSearch(vpTree, 3, record);
        vpTreeFillSearch.search();
        //System.out.println(vpTreeFillSearch.getResults());
        System.out.println(vpTreeFillSearch.getDistances());
    }
}
@ -1,46 +0,0 @@
<!--
  ~ /* ******************************************************************************
  ~ *
  ~ *
  ~ * This program and the accompanying materials are made available under the
  ~ * terms of the Apache License, Version 2.0 which is available at
  ~ * https://www.apache.org/licenses/LICENSE-2.0.
  ~ *
  ~ * See the NOTICE file distributed with this work for additional
  ~ * information regarding copyright ownership.
  ~ * Unless required by applicable law or agreed to in writing, software
  ~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~ * License for the specific language governing permissions and limitations
  ~ * under the License.
  ~ *
  ~ * SPDX-License-Identifier: Apache-2.0
  ~ ******************************************************************************/
-->

<configuration>

    <appender name="FILE" class="ch.qos.logback.core.FileAppender">
        <file>logs/application.log</file>
        <encoder>
            <pattern> %logger{15} - %message%n%xException{5}
            </pattern>
        </encoder>
    </appender>

    <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
        <encoder>
            <pattern> %logger{15} - %message%n%xException{5}
            </pattern>
        </encoder>
    </appender>

    <logger name="org.deeplearning4j" level="INFO" />
    <logger name="org.datavec" level="INFO" />
    <logger name="org.nd4j" level="INFO" />

    <root level="ERROR">
        <appender-ref ref="STDOUT" />
        <appender-ref ref="FILE" />
    </root>
</configuration>
@ -1,60 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  ~ /* ******************************************************************************
  ~ *
  ~ *
  ~ * This program and the accompanying materials are made available under the
  ~ * terms of the Apache License, Version 2.0 which is available at
  ~ * https://www.apache.org/licenses/LICENSE-2.0.
  ~ *
  ~ * See the NOTICE file distributed with this work for additional
  ~ * information regarding copyright ownership.
  ~ * Unless required by applicable law or agreed to in writing, software
  ~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~ * License for the specific language governing permissions and limitations
  ~ * under the License.
  ~ *
  ~ * SPDX-License-Identifier: Apache-2.0
  ~ ******************************************************************************/
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-nearestneighbors-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>deeplearning4j-nearestneighbors-client</artifactId>
    <packaging>jar</packaging>

    <name>deeplearning4j-nearestneighbors-client</name>

    <dependencies>
        <dependency>
            <groupId>com.mashape.unirest</groupId>
            <artifactId>unirest-java</artifactId>
            <version>${unirest.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-nearestneighbors-model</artifactId>
            <version>${project.version}</version>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
        </profile>
    </profiles>
</project>
@ -1,137 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nearestneighbor.client;

import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import com.mashape.unirest.request.HttpRequest;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import lombok.val;
import org.deeplearning4j.nearestneighbor.model.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;
import org.nd4j.shade.jackson.core.JsonProcessingException;

import java.io.IOException;

@AllArgsConstructor
public class NearestNeighborsClient {

    private String url;
    @Setter
    @Getter
    protected String authToken;

    public NearestNeighborsClient(String url) {
        this(url, null);
    }

    static {
        // Configure Unirest's object mapper once for the whole JVM
        Unirest.setObjectMapper(new ObjectMapper() {
            private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
                            new org.nd4j.shade.jackson.databind.ObjectMapper();

            public <T> T readValue(String value, Class<T> valueType) {
                try {
                    return jacksonObjectMapper.readValue(value, valueType);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }

            public String writeValue(Object value) {
                try {
                    return jacksonObjectMapper.writeValueAsString(value);
                } catch (JsonProcessingException e) {
                    throw new RuntimeException(e);
                }
            }
        });
    }

    /**
     * Runs a k-nearest-neighbors search on the row at the given index.
     * Note that this is for data already within the existing dataset, not new data.
     *
     * @param index the index of the EXISTING ndarray row to run a search on
     * @param k the number of results to retrieve
     * @return the nearest-neighbor results for that row
     * @throws Exception if the request fails
     */
    public NearestNeighborsResults knn(int index, int k) throws Exception {
        NearestNeighborRequest request = new NearestNeighborRequest();
        request.setInputIndex(index);
        request.setK(k);
        val req = Unirest.post(url + "/knn");
        req.header("accept", "application/json")
                .header("Content-Type", "application/json").body(request);
        addAuthHeader(req);

        NearestNeighborsResults ret = req.asObject(NearestNeighborsResults.class).getBody();
        return ret;
    }

    /**
     * Run a k-nearest-neighbors search on a NEW data point.
     *
     * @param k the number of results to retrieve
     * @param arr the array to run the search on; note that this must be a row vector
     * @return the nearest-neighbor results for the given vector
     * @throws Exception if the request fails
     */
    public NearestNeighborsResults knnNew(int k, INDArray arr) throws Exception {
        Base64NDArrayBody base64NDArrayBody =
                        Base64NDArrayBody.builder().k(k).ndarray(Nd4jBase64.base64String(arr)).build();

        val req = Unirest.post(url + "/knnnew");
        req.header("accept", "application/json")
                .header("Content-Type", "application/json").body(base64NDArrayBody);
        addAuthHeader(req);

        NearestNeighborsResults ret = req.asObject(NearestNeighborsResults.class).getBody();

        return ret;
    }

    /**
     * Add the specified authentication header to the specified HttpRequest
     *
     * @param request HTTP Request to add the authentication header to
     */
    protected HttpRequest addAuthHeader(HttpRequest request) {
        if (authToken != null) {
            request.header("authorization", "Bearer " + authToken);
        }

        return request;
    }
}
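For reference, a minimal client sketch against a locally running server; the URL, token, and query vector are illustrative:

NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:9000");
client.setAuthToken("my-token");  // optional; sent as a Bearer authorization header
NearestNeighborsResults byIndex = client.knn(0, 5);                      // 5 neighbors of existing row 0
NearestNeighborsResults byVector = client.knnNew(5, Nd4j.rand(1, 100));  // 5 neighbors of a new row vector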
@ -1,61 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  ~ /* ******************************************************************************
  ~ *
  ~ *
  ~ * This program and the accompanying materials are made available under the
  ~ * terms of the Apache License, Version 2.0 which is available at
  ~ * https://www.apache.org/licenses/LICENSE-2.0.
  ~ *
  ~ * See the NOTICE file distributed with this work for additional
  ~ * information regarding copyright ownership.
  ~ * Unless required by applicable law or agreed to in writing, software
  ~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~ * License for the specific language governing permissions and limitations
  ~ * under the License.
  ~ *
  ~ * SPDX-License-Identifier: Apache-2.0
  ~ ******************************************************************************/
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-nearestneighbors-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>deeplearning4j-nearestneighbors-model</artifactId>
    <packaging>jar</packaging>

    <name>deeplearning4j-nearestneighbors-model</name>

    <dependencies>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>${lombok.version}</version>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-api</artifactId>
            <version>${nd4j.version}</version>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
        </profile>
    </profiles>
</project>
@ -1,38 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nearestneighbor.model;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.io.Serializable;

@Data
@AllArgsConstructor
@NoArgsConstructor
@Builder
public class Base64NDArrayBody implements Serializable {
    private String ndarray;
    private int k;
    private boolean forceFillK;
}
@ -1,65 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nearestneighbor.model;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.dataset.DataSet;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

@Data
@AllArgsConstructor
@Builder
@NoArgsConstructor
public class BatchRecord implements Serializable {
    private List<CSVRecord> records;

    /**
     * Add a record
     * @param record the record to add
     */
    public void add(CSVRecord record) {
        if (records == null)
            records = new ArrayList<>();
        records.add(record);
    }

    /**
     * Return a batch record based on a dataset
     * @param dataSet the dataset to get the batch record for
     * @return the batch record
     */
    public static BatchRecord fromDataSet(DataSet dataSet) {
        BatchRecord batchRecord = new BatchRecord();
        for (int i = 0; i < dataSet.numExamples(); i++) {
            batchRecord.add(CSVRecord.fromRow(dataSet.get(i)));
        }

        return batchRecord;
    }
}
@ -1,85 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nearestneighbor.model;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.dataset.DataSet;

import java.io.Serializable;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class CSVRecord implements Serializable {
    private String[] values;

    /**
     * Instantiate a csv record from a {@link DataSet} row.
     * For classification (one-hot labels), the index of the maximum label value
     * is appended to the end of the record; for regression, all label values
     * are appended.
     * @param row the input dataset row (features and labels must be vectors or scalars)
     * @return the record for this {@link DataSet}
     */
    public static CSVRecord fromRow(DataSet row) {
        if (!row.getFeatures().isVector() && !row.getFeatures().isScalar())
            throw new IllegalArgumentException("Passed in dataset must represent a scalar or vector");
        if (!row.getLabels().isVector() && !row.getLabels().isScalar())
            throw new IllegalArgumentException("Passed in dataset labels must be a scalar or vector");
        //classification
        CSVRecord record;
        int idx = 0;
        if (row.getLabels().sumNumber().doubleValue() == 1.0) {
            String[] values = new String[row.getFeatures().columns() + 1];
            for (int i = 0; i < row.getFeatures().length(); i++) {
                values[idx++] = String.valueOf(row.getFeatures().getDouble(i));
            }
            int maxIdx = 0;
            for (int i = 0; i < row.getLabels().length(); i++) {
                if (row.getLabels().getDouble(maxIdx) < row.getLabels().getDouble(i)) {
                    maxIdx = i;
                }
            }

            values[idx++] = String.valueOf(maxIdx);
            record = new CSVRecord(values);
        }
        //regression (any number of values)
        else {
            String[] values = new String[row.getFeatures().columns() + row.getLabels().columns()];
            for (int i = 0; i < row.getFeatures().length(); i++) {
                values[idx++] = String.valueOf(row.getFeatures().getDouble(i));
            }
            for (int i = 0; i < row.getLabels().length(); i++) {
                values[idx++] = String.valueOf(row.getLabels().getDouble(i));
            }

            record = new CSVRecord(values);
        }
        return record;
    }
}
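For intuition, a sketch of what fromRow produces for a classification row with one-hot labels; the feature and label values are illustrative:

// features [0.5, 1.5] with one-hot labels [0, 1] -> values ["0.5", "1.5", "1"]
DataSet row = new DataSet(Nd4j.create(new double[] {0.5, 1.5}, new int[] {1, 2}),
                Nd4j.create(new double[] {0, 1}, new int[] {1, 2}));
CSVRecord record = CSVRecord.fromRow(row);  // the label argmax is appended as the class index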
@ -1,32 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nearestneighbor.model;

import lombok.Data;

import java.io.Serializable;

@Data
public class NearestNeighborRequest implements Serializable {
    private int k;
    private int inputIndex;
}
@ -1,37 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nearestneighbor.model;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class NearestNeighborsResult {
    private int index;
    private double distance;
    private String label;

    public NearestNeighborsResult(int index, double distance) {
        this(index, distance, null);
    }
}
@ -1,38 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nearestneighbor.model;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.io.Serializable;
import java.util.List;

@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class NearestNeighborsResults implements Serializable {
    private List<NearestNeighborsResult> results;
}
@ -1,103 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  ~ /* ******************************************************************************
  ~ *
  ~ *
  ~ * This program and the accompanying materials are made available under the
  ~ * terms of the Apache License, Version 2.0 which is available at
  ~ * https://www.apache.org/licenses/LICENSE-2.0.
  ~ *
  ~ * See the NOTICE file distributed with this work for additional
  ~ * information regarding copyright ownership.
  ~ * Unless required by applicable law or agreed to in writing, software
  ~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~ * License for the specific language governing permissions and limitations
  ~ * under the License.
  ~ *
  ~ * SPDX-License-Identifier: Apache-2.0
  ~ ******************************************************************************/
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-nearestneighbors-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>nearestneighbor-core</artifactId>
    <packaging>jar</packaging>

    <name>nearestneighbor-core</name>

    <dependencies>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-api</artifactId>
            <version>${nd4j.version}</version>
        </dependency>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-nn</artifactId>
            <version>${project.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-datasets</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>joda-time</groupId>
            <artifactId>joda-time</artifactId>
            <version>2.10.3</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-common-tests</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
            <dependencies>
                <dependency>
                    <groupId>org.nd4j</groupId>
                    <artifactId>nd4j-native</artifactId>
                    <version>${project.version}</version>
                    <scope>test</scope>
                </dependency>
            </dependencies>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
            <dependencies>
                <dependency>
                    <groupId>org.nd4j</groupId>
                    <artifactId>nd4j-cuda-11.0</artifactId>
                    <version>${project.version}</version>
                    <scope>test</scope>
                </dependency>
            </dependencies>
        </profile>
    </profiles>
</project>
@ -1,218 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.algorithm;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.clustering.cluster.Cluster;
import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.ClusterUtils;
import org.deeplearning4j.clustering.cluster.Point;
import org.deeplearning4j.clustering.info.ClusterSetInfo;
import org.deeplearning4j.clustering.iteration.IterationHistory;
import org.deeplearning4j.clustering.iteration.IterationInfo;
import org.deeplearning4j.clustering.strategy.ClusteringStrategy;
import org.deeplearning4j.clustering.strategy.ClusteringStrategyType;
import org.deeplearning4j.clustering.strategy.OptimisationStrategy;
import org.deeplearning4j.clustering.util.MultiThreadUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;

@Slf4j
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializable {

    private static final long serialVersionUID = 338231277453149972L;

    private ClusteringStrategy clusteringStrategy;
    private IterationHistory iterationHistory;
    private int currentIteration = 0;
    private ClusterSet clusterSet;
    private List<Point> initialPoints;
    private transient ExecutorService exec;
    private boolean useKmeansPlusPlus;

    protected BaseClusteringAlgorithm(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) {
        this.clusteringStrategy = clusteringStrategy;
        this.exec = MultiThreadUtils.newExecutorService();
        this.useKmeansPlusPlus = useKmeansPlusPlus;
    }

    /**
     * Set up a clustering algorithm instance for the given strategy.
     *
     * @param clusteringStrategy the strategy (cluster count, distance function, termination condition) to apply
     * @param useKmeansPlusPlus  whether to use k-means++ seeding when picking the initial cluster centers
     * @return a new {@link BaseClusteringAlgorithm}
     */
    public static BaseClusteringAlgorithm setup(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) {
        return new BaseClusteringAlgorithm(clusteringStrategy, useKmeansPlusPlus);
    }

    /**
     * Run the clustering algorithm on the given points.
     *
     * @param points the points to cluster
     * @return the resulting {@link ClusterSet}
     */
    public ClusterSet applyTo(List<Point> points) {
        resetState(points);
        initClusters(useKmeansPlusPlus);
        iterations();
        return clusterSet;
    }
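
    // Usage sketch (hypothetical caller, not part of this diff): the strategy
    // factory and setup arguments named below are assumptions about the wider
    // clustering API, shown only to illustrate the intended call sequence.
    //
    //   ClusteringStrategy strategy =
    //           FixedClusterCountStrategy.setup(5, Distance.EUCLIDEAN, false);
    //   BaseClusteringAlgorithm kMeans = BaseClusteringAlgorithm.setup(strategy, true);
    //   ClusterSet clusters = kMeans.applyTo(points);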

    private void resetState(List<Point> points) {
        this.iterationHistory = new IterationHistory();
        this.currentIteration = 0;
        this.clusterSet = null;
        this.initialPoints = points;
    }

    /**
     * Run clustering iterations until a termination condition is hit.
     * This is done by first classifying all points, and then updating
     * cluster centers based on those classified points.
     */
    private void iterations() {
        int iterationCount = 0;
        while ((clusteringStrategy.getTerminationCondition() != null
                        && !clusteringStrategy.getTerminationCondition().isSatisfied(iterationHistory))
                        || iterationHistory.getMostRecentIterationInfo().isStrategyApplied()) {
            currentIteration++;
            removePoints();
            classifyPoints();
            applyClusteringStrategy();
            log.trace("Completed clustering iteration {}", ++iterationCount);
        }
    }
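
    // Note on the loop condition above: iteration continues while the strategy's
    // termination condition is not yet satisfied, or while the previous pass
    // still modified the cluster set (strategy applied), so the loop only exits
    // once the clustering has stabilised.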

    protected void classifyPoints() {
        //Classify points. This also adds each point to the ClusterSet
        ClusterSetInfo clusterSetInfo = ClusterUtils.classifyPoints(clusterSet, initialPoints, exec);
        //Update the cluster centers, based on the points within each cluster
        ClusterUtils.refreshClustersCenters(clusterSet, clusterSetInfo, exec);
        iterationHistory.getIterationsInfos().put(currentIteration,
                        new IterationInfo(currentIteration, clusterSetInfo));
    }

    /**
     * Initialize the cluster centers at random
     */
    protected void initClusters(boolean kMeansPlusPlus) {
        log.info("Generating initial clusters");
        List<Point> points = new ArrayList<>(initialPoints);

        //Initialize the ClusterSet with a single cluster center (based on position of one of the points chosen randomly)
        val random = Nd4j.getRandom();
        Distance distanceFn = clusteringStrategy.getDistanceFunction();
        int initialClusterCount = clusteringStrategy.getInitialClusterCount();
        clusterSet = new ClusterSet(distanceFn,
                        clusteringStrategy.inverseDistanceCalculation(), new long[]{initialClusterCount, points.get(0).getArray().length()});
        clusterSet.addNewClusterWithCenter(points.remove(random.nextInt(points.size())));

        //dxs: distances between each point and nearest cluster to that point
        INDArray dxs = Nd4j.create(points.size());
        dxs.addi(clusteringStrategy.inverseDistanceCalculation() ? -Double.MAX_VALUE : Double.MAX_VALUE);

        //Generate the initial cluster centers, by randomly selecting a point between 0 and max distance
        //Thus, we are more likely to select (as a new cluster center) a point that is far from an existing cluster
        while (clusterSet.getClusterCount() < initialClusterCount && !points.isEmpty()) {
            dxs = ClusterUtils.computeSquareDistancesFromNearestCluster(clusterSet, points, dxs, exec);
            double summed = Nd4j.sum(dxs).getDouble(0);
            double r = kMeansPlusPlus ? random.nextDouble() * summed
                            : random.nextFloat() * dxs.maxNumber().doubleValue();

            for (int i = 0; i < dxs.length(); i++) {
                double distance = dxs.getDouble(i);
                Preconditions.checkState(distance >= 0, "Encountered negative distance: distance function is not valid? Distance " +
                                "function must return values >= 0, got distance %s for function %s", distance, distanceFn);
                if (dxs.getDouble(i) >= r) {
                    clusterSet.addNewClusterWithCenter(points.remove(i));
                    dxs = Nd4j.create(ArrayUtils.remove(dxs.data().asDouble(), i));
                    break;
                }
            }
        }

        ClusterSetInfo initialClusterSetInfo = ClusterUtils.computeClusterSetInfo(clusterSet);
        iterationHistory.getIterationsInfos().put(currentIteration,
                        new IterationInfo(currentIteration, initialClusterSetInfo));
    }
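
    // Seeding recap for initClusters: with kMeansPlusPlus the threshold r is a
    // uniform draw over the total of squared nearest-center distances, so the
    // next center is accepted with probability roughly proportional to D(x)^2;
    // otherwise r is a uniform draw scaled by the current maximum distance.
    // Either way, points far from all existing centers are favoured.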

    protected void applyClusteringStrategy() {
        if (!isStrategyApplicableNow())
            return;

        ClusterSetInfo clusterSetInfo = iterationHistory.getMostRecentClusterSetInfo();
        if (!clusteringStrategy.isAllowEmptyClusters()) {
            int removedCount = removeEmptyClusters(clusterSetInfo);
            if (removedCount > 0) {
                iterationHistory.getMostRecentIterationInfo().setStrategyApplied(true);

                if (clusteringStrategy.isStrategyOfType(ClusteringStrategyType.FIXED_CLUSTER_COUNT)
                                && clusterSet.getClusterCount() < clusteringStrategy.getInitialClusterCount()) {
                    int splitCount = ClusterUtils.splitMostSpreadOutClusters(clusterSet, clusterSetInfo,
                                    clusteringStrategy.getInitialClusterCount() - clusterSet.getClusterCount(), exec);
                    if (splitCount > 0)
                        iterationHistory.getMostRecentIterationInfo().setStrategyApplied(true);
                }
            }
        }
        if (clusteringStrategy.isStrategyOfType(ClusteringStrategyType.OPTIMIZATION))
            optimize();
    }

    protected void optimize() {
        ClusterSetInfo clusterSetInfo = iterationHistory.getMostRecentClusterSetInfo();
        OptimisationStrategy optimization = (OptimisationStrategy) clusteringStrategy;
        boolean applied = ClusterUtils.applyOptimization(optimization, clusterSet, clusterSetInfo, exec);
        iterationHistory.getMostRecentIterationInfo().setStrategyApplied(applied);
    }

    private boolean isStrategyApplicableNow() {
        return clusteringStrategy.isOptimizationDefined() && iterationHistory.getIterationCount() != 0
                        && clusteringStrategy.isOptimizationApplicableNow(iterationHistory);
    }

    protected int removeEmptyClusters(ClusterSetInfo clusterSetInfo) {
        List<Cluster> removedClusters = clusterSet.removeEmptyClusters();
        clusterSetInfo.removeClusterInfos(removedClusters);
        return removedClusters.size();
    }

    protected void removePoints() {
        clusterSet.removePoints();
    }

}
@ -1,38 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.algorithm;

import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.Point;

import java.util.List;

public interface ClusteringAlgorithm {

    /**
     * Apply the clustering algorithm to the given points.
     *
     * @param points the points to cluster
     * @return the resulting {@link ClusterSet}
     */
    ClusterSet applyTo(List<Point> points);

}
@ -1,41 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.algorithm;

public enum Distance {
    EUCLIDEAN("euclidean"),
    COSINE_DISTANCE("cosinedistance"),
    COSINE_SIMILARITY("cosinesimilarity"),
    MANHATTAN("manhattan"),
    DOT("dot"),
    JACCARD("jaccard"),
    HAMMING("hamming");

    private final String functionName;

    Distance(String name) {
        functionName = name;
    }

    @Override
    public String toString() {
        return functionName;
    }
}
@ -1,105 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.cluster;

import org.deeplearning4j.clustering.algorithm.Distance;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;

public class CentersHolder {

    private INDArray centers;
    private long index = 0;

    protected transient ReduceOp op;
    protected ArgMin imin;
    protected transient INDArray distances;
    protected transient INDArray argMin;

    private long rows, cols;

    public CentersHolder(long rows, long cols) {
        this.rows = rows;
        this.cols = cols;
    }

    public INDArray getCenters() {
        return this.centers;
    }

    public synchronized void addCenter(INDArray pointView) {
        if (centers == null)
            this.centers = Nd4j.create(pointView.dataType(), new long[] {rows, cols});

        centers.putRow(index++, pointView);
    }

    public synchronized Pair<Double, Long> getCenterByMinDistance(Point point, Distance distanceFunction) {
        if (distances == null)
            distances = Nd4j.create(centers.dataType(), centers.rows());

        if (argMin == null)
            argMin = Nd4j.createUninitialized(DataType.LONG, new long[0]);

        if (op == null) {
            op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1);
            imin = new ArgMin(distances, argMin);
            op.setZ(distances);
        }

        op.setY(point.getArray());

        Nd4j.getExecutioner().exec(op);
        Nd4j.getExecutioner().exec(imin);

        Pair<Double, Long> result = new Pair<>();
        result.setFirst(distances.getDouble(argMin.getLong(0)));
        result.setSecond(argMin.getLong(0));
        return result;
    }
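
    // The ReduceOp and ArgMin above are built lazily on first use and then
    // reused: setY swaps in the query point while the centers matrix and the
    // preallocated distance/argmin buffers stay fixed, so repeated
    // nearest-center queries avoid per-call op and buffer allocations.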

    public synchronized INDArray getMinDistances(Point point, Distance distanceFunction) {
        if (distances == null)
            distances = Nd4j.create(centers.dataType(), centers.rows());

        if (argMin == null)
            argMin = Nd4j.createUninitialized(DataType.LONG, new long[0]);

        if (op == null) {
            op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1);
            imin = new ArgMin(distances, argMin);
            op.setZ(distances);
        }

        op.setY(point.getArray());

        Nd4j.getExecutioner().exec(op);
        Nd4j.getExecutioner().exec(imin);

        return distances;
    }

}
@ -1,150 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.cluster;

import lombok.Data;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.nd4j.linalg.factory.Nd4j;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.UUID;

@Data
public class Cluster implements Serializable {

    private String id = UUID.randomUUID().toString();
    private String label;

    private Point center;
    private List<Point> points = Collections.synchronizedList(new ArrayList<Point>());
    private boolean inverse = false;
    private Distance distanceFunction;

    public Cluster() {
        super();
    }

    /**
     * @param center           the initial cluster center
     * @param distanceFunction the distance function to use
     */
    public Cluster(Point center, Distance distanceFunction) {
        this(center, false, distanceFunction);
    }

    /**
     * @param center           the initial cluster center
     * @param inverse          whether the distance function is inverted (higher means closer)
     * @param distanceFunction the distance function to use
     */
    public Cluster(Point center, boolean inverse, Distance distanceFunction) {
        this.distanceFunction = distanceFunction;
        this.inverse = inverse;
        setCenter(center);
    }

    /**
     * Get the distance from the given point to the cluster center
     * @param point the point to get the distance for
     * @return the distance between the point and the cluster center
     */
    public double getDistanceToCenter(Point point) {
        return Nd4j.getExecutioner().execAndReturn(
                        ClusterUtils.createDistanceFunctionOp(distanceFunction, center.getArray(), point.getArray()))
                        .getFinalResult().doubleValue();
    }

    /**
     * Add a point to the cluster, moving the cluster center
     * @param point the point to add
     */
    public void addPoint(Point point) {
        addPoint(point, true);
    }

    /**
     * Add a point to the cluster
     * @param point the point to add
     * @param moveClusterCenter whether to update the cluster centroid or not
     */
    public void addPoint(Point point, boolean moveClusterCenter) {
        if (moveClusterCenter) {
            if (isInverse()) {
                center.getArray().muli(points.size()).subi(point.getArray()).divi(points.size() + 1);
            } else {
                center.getArray().muli(points.size()).addi(point.getArray()).divi(points.size() + 1);
            }
        }

        getPoints().add(point);
    }
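
    // Centroid update above: with n points already in the cluster, the new
    // center is (center * n + p) / (n + 1), subtracting p instead in the
    // inverse case; i.e. a running mean maintained in place without having to
    // revisit previously added points.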

    /**
     * Clear out the points
     */
    public void removePoints() {
        if (getPoints() != null)
            getPoints().clear();
    }

    /**
     * Whether the cluster is empty or not
     * @return true if the cluster contains no points
     */
    public boolean isEmpty() {
        return points == null || points.isEmpty();
    }

    /**
     * Return the point with the given id
     * @param id the id of the point to look up
     * @return the matching point, or null if not found
     */
    public Point getPoint(String id) {
        for (Point point : points)
            if (id.equals(point.getId()))
                return point;
        return null;
    }

    /**
     * Remove the point with the given id and return it
     * @param id the id of the point to remove
     * @return the removed point, or null if not found
     */
    public Point removePoint(String id) {
        Point removePoint = null;
        for (Point point : points)
            if (id.equals(point.getId()))
                removePoint = point;
        if (removePoint != null)
            points.remove(removePoint);
        return removePoint;
    }

}
@ -1,259 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.cluster;

import lombok.Data;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;

import java.io.Serializable;
import java.util.*;

@Data
public class ClusterSet implements Serializable {

    private Distance distanceFunction;
    private List<Cluster> clusters;
    private CentersHolder centersHolder;
    private Map<String, String> pointDistribution;
    private boolean inverse;

    public ClusterSet(boolean inverse) {
        this(null, inverse, null);
    }

    public ClusterSet(Distance distanceFunction, boolean inverse, long[] shape) {
        this.distanceFunction = distanceFunction;
        this.inverse = inverse;
        this.clusters = Collections.synchronizedList(new ArrayList<Cluster>());
        this.pointDistribution = Collections.synchronizedMap(new HashMap<String, String>());
        if (shape != null)
            this.centersHolder = new CentersHolder(shape[0], shape[1]);
    }

    public boolean isInverse() {
        return inverse;
    }

    /**
     * Create a new cluster centered on the given point and add it to the set.
     *
     * @param center the center of the new cluster
     * @return the newly added cluster
     */
    public Cluster addNewClusterWithCenter(Point center) {
        Cluster newCluster = new Cluster(center, distanceFunction);
        getClusters().add(newCluster);
        setPointLocation(center, newCluster);
        centersHolder.addCenter(center.getArray());
        return newCluster;
    }

    /**
     * Classify the given point, moving the cluster center.
     *
     * @param point the point to classify
     * @return the resulting classification
     */
    public PointClassification classifyPoint(Point point) {
        return classifyPoint(point, true);
    }

    /**
     * Classify each of the given points, moving the cluster centers.
     *
     * @param points the points to classify
     */
    public void classifyPoints(List<Point> points) {
        classifyPoints(points, true);
    }

    /**
     * Classify each of the given points.
     *
     * @param points            the points to classify
     * @param moveClusterCenter whether to update the cluster centers as points are added
     */
    public void classifyPoints(List<Point> points, boolean moveClusterCenter) {
        for (Point point : points)
            classifyPoint(point, moveClusterCenter);
    }

    /**
     * Assign the point to its nearest cluster.
     *
     * @param point             the point to classify
     * @param moveClusterCenter whether to update the cluster center as the point is added
     * @return the resulting classification
     */
    public PointClassification classifyPoint(Point point, boolean moveClusterCenter) {
        Pair<Cluster, Double> nearestCluster = nearestCluster(point);
        Cluster newCluster = nearestCluster.getKey();
        boolean locationChange = isPointLocationChange(point, newCluster);
        addPointToCluster(point, newCluster, moveClusterCenter);
        return new PointClassification(nearestCluster.getKey(), nearestCluster.getValue(), locationChange);
    }

    private boolean isPointLocationChange(Point point, Cluster newCluster) {
        if (!getPointDistribution().containsKey(point.getId()))
            return true;
        return !getPointDistribution().get(point.getId()).equals(newCluster.getId());
    }

    private void addPointToCluster(Point point, Cluster cluster, boolean moveClusterCenter) {
        cluster.addPoint(point, moveClusterCenter);
        setPointLocation(point, cluster);
    }

    private void setPointLocation(Point point, Cluster cluster) {
        pointDistribution.put(point.getId(), cluster.getId());
    }

    /**
     * Find the nearest cluster to the given point.
     *
     * @param point the point to look up
     * @return a pair of (nearest cluster, distance to its center)
     */
    public Pair<Cluster, Double> nearestCluster(Point point) {

        /*double minDistance = isInverse() ? Float.MIN_VALUE : Float.MAX_VALUE;

        double currentDistance;
        for (Cluster cluster : getClusters()) {
            currentDistance = cluster.getDistanceToCenter(point);
            if (isInverse()) {
                if (currentDistance > minDistance) {
                    minDistance = currentDistance;
                    nearestCluster = cluster;
                }
            } else {
                if (currentDistance < minDistance) {
                    minDistance = currentDistance;
                    nearestCluster = cluster;
                }
            }

        }*/

        Pair<Double, Long> nearestCenterData = centersHolder.getCenterByMinDistance(point, distanceFunction);
        Cluster nearestCluster = getClusters().get(nearestCenterData.getSecond().intValue());
        double minDistance = nearestCenterData.getFirst();
        return Pair.of(nearestCluster, minDistance);
    }
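
    // nearestCluster above delegates to CentersHolder, which computes the
    // distances to all centers in one vectorised reduce followed by an ArgMin;
    // the per-cluster Java loop it replaced is kept commented out above for
    // reference.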

    /**
     * Compute the distance between two points under this set's distance function.
     *
     * @param m1 the first point
     * @param m2 the second point
     * @return the distance between the two points
     */
    public double getDistance(Point m1, Point m2) {
        return Nd4j.getExecutioner()
                        .execAndReturn(ClusterUtils.createDistanceFunctionOp(distanceFunction, m1.getArray(), m2.getArray()))
                        .getFinalResult().doubleValue();
    }

    /*public double getDistanceFromNearestCluster(Point point) {
        return nearestCluster(point).getValue();
    }*/

    /**
     * @param clusterId the id of the cluster
     * @return the id of that cluster's center point, or null if the cluster is unknown
     */
    public String getClusterCenterId(String clusterId) {
        Point clusterCenter = getClusterCenter(clusterId);
        return clusterCenter == null ? null : clusterCenter.getId();
    }

    /**
     * @param clusterId the id of the cluster
     * @return the cluster's center point, or null if the cluster is unknown
     */
    public Point getClusterCenter(String clusterId) {
        Cluster cluster = getCluster(clusterId);
        return cluster == null ? null : cluster.getCenter();
    }

    /**
     * @param id the id of the cluster
     * @return the matching cluster, or null if not found
     */
    public Cluster getCluster(String id) {
        for (int i = 0, j = clusters.size(); i < j; i++)
            if (id.equals(clusters.get(i).getId()))
                return clusters.get(i);
        return null;
    }

    /**
     * @return the number of clusters in the set
     */
    public int getClusterCount() {
        return getClusters() == null ? 0 : getClusters().size();
    }

    /**
     * Remove all points from all clusters.
     */
    public void removePoints() {
        for (Cluster cluster : getClusters())
            cluster.removePoints();
    }

    /**
     * @param count the number of clusters to return
     * @return the {@code count} clusters containing the most points
     */
    public List<Cluster> getMostPopulatedClusters(int count) {
        List<Cluster> mostPopulated = new ArrayList<>(clusters);
        Collections.sort(mostPopulated, new Comparator<Cluster>() {
            public int compare(Cluster o1, Cluster o2) {
                return Integer.compare(o2.getPoints().size(), o1.getPoints().size());
            }
        });
        return mostPopulated.subList(0, count);
    }

    /**
     * Remove all empty clusters from the set.
     *
     * @return the clusters that were removed
     */
    public List<Cluster> removeEmptyClusters() {
        List<Cluster> emptyClusters = new ArrayList<>();
        for (Cluster cluster : clusters)
            if (cluster.isEmpty())
                emptyClusters.add(cluster);
        clusters.removeAll(emptyClusters);
        return emptyClusters;
    }

}
@ -1,531 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.cluster;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.info.ClusterInfo;
import org.deeplearning4j.clustering.info.ClusterSetInfo;
import org.deeplearning4j.clustering.optimisation.ClusteringOptimizationType;
import org.deeplearning4j.clustering.strategy.OptimisationStrategy;
import org.deeplearning4j.clustering.util.MathUtils;
import org.deeplearning4j.clustering.util.MultiThreadUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.reduce3.*;
import org.nd4j.linalg.factory.Nd4j;

import java.util.*;
import java.util.concurrent.ExecutorService;

@NoArgsConstructor(access = AccessLevel.PRIVATE)
@Slf4j
public class ClusterUtils {

    /** Classify the set of points based on cluster centers. This also adds each point to the ClusterSet */
    public static ClusterSetInfo classifyPoints(final ClusterSet clusterSet, List<Point> points,
                    ExecutorService executorService) {
        final ClusterSetInfo clusterSetInfo = ClusterSetInfo.initialize(clusterSet, true);

        List<Runnable> tasks = new ArrayList<>();
        for (final Point point : points) {
            //tasks.add(new Runnable() {
            //    public void run() {
            try {
                PointClassification result = classifyPoint(clusterSet, point);
                if (result.isNewLocation())
                    clusterSetInfo.getPointLocationChange().incrementAndGet();
                clusterSetInfo.getClusterInfo(result.getCluster().getId()).getPointDistancesFromCenter()
                                .put(point.getId(), result.getDistanceFromCenter());
            } catch (Throwable t) {
                log.warn("Error classifying point", t);
            }
            //    }
        }

        //MultiThreadUtils.parallelTasks(tasks, executorService);
        return clusterSetInfo;
    }
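
    // Note: the Runnable scaffolding in classifyPoints (and in the refresh and
    // info-computation methods below) is commented out, so these loops run
    // sequentially on the calling thread; the executorService parameter is kept
    // only for signature compatibility.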

    public static PointClassification classifyPoint(ClusterSet clusterSet, Point point) {
        return clusterSet.classifyPoint(point, false);
    }

    public static void refreshClustersCenters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo,
                    ExecutorService executorService) {
        List<Runnable> tasks = new ArrayList<>();
        int nClusters = clusterSet.getClusterCount();
        for (int i = 0; i < nClusters; i++) {
            final Cluster cluster = clusterSet.getClusters().get(i);
            //tasks.add(new Runnable() {
            //    public void run() {
            try {
                final ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
                refreshClusterCenter(cluster, clusterInfo);
                deriveClusterInfoDistanceStatistics(clusterInfo);
            } catch (Throwable t) {
                log.warn("Error refreshing cluster centers", t);
            }
            //    }
            //});
        }
        //MultiThreadUtils.parallelTasks(tasks, executorService);
    }

    public static void refreshClusterCenter(Cluster cluster, ClusterInfo clusterInfo) {
        int pointsCount = cluster.getPoints().size();
        if (pointsCount == 0)
            return;
        Point center = new Point(Nd4j.create(cluster.getPoints().get(0).getArray().length()));
        for (Point point : cluster.getPoints()) {
            INDArray arr = point.getArray();
            if (cluster.isInverse())
                center.getArray().subi(arr);
            else
                center.getArray().addi(arr);
        }
        center.getArray().divi(pointsCount);
        cluster.setCenter(center);
    }

    /**
     * Derive the extreme, total, average and variance statistics of the
     * point-to-center distances for the given cluster.
     *
     * @param info the cluster info to update
     */
    public static void deriveClusterInfoDistanceStatistics(ClusterInfo info) {
        int pointCount = info.getPointDistancesFromCenter().size();
        if (pointCount == 0)
            return;

        double[] distances =
                        ArrayUtils.toPrimitive(info.getPointDistancesFromCenter().values().toArray(new Double[] {}));
        double max = info.isInverse() ? MathUtils.min(distances) : MathUtils.max(distances);
        double total = MathUtils.sum(distances);
        info.setMaxPointDistanceFromCenter(max);
        info.setTotalPointDistanceFromCenter(total);
        info.setAveragePointDistanceFromCenter(total / pointCount);
        info.setPointDistanceFromCenterVariance(MathUtils.variance(distances));
    }

    /**
     * Compute, for every remaining point, the squared distance to its nearest
     * cluster center, reusing the minima from the previous round.
     *
     * @param clusterSet      the current cluster set
     * @param points          the candidate points
     * @param previousDxs     the per-point minima from the previous invocation
     * @param executorService the executor (currently unused, see note above)
     * @return the updated per-point distances
     */
    public static INDArray computeSquareDistancesFromNearestCluster(final ClusterSet clusterSet,
                    final List<Point> points, INDArray previousDxs, ExecutorService executorService) {
        final int pointsCount = points.size();
        final INDArray dxs = Nd4j.create(pointsCount);
        final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1);

        List<Runnable> tasks = new ArrayList<>();
        for (int i = 0; i < pointsCount; i++) {
            final int i2 = i;
            //tasks.add(new Runnable() {
            //    public void run() {
            try {
                Point point = points.get(i2);
                double dist = clusterSet.isInverse() ? newCluster.getDistanceToCenter(point)
                                : Math.pow(newCluster.getDistanceToCenter(point), 2);
                dxs.putScalar(i2, /*clusterSet.isInverse() ? dist :*/ dist);
            } catch (Throwable t) {
                log.warn("Error computing squared distance from nearest cluster", t);
            }
            //    }
            //});
        }

        //MultiThreadUtils.parallelTasks(tasks, executorService);
        for (int i = 0; i < pointsCount; i++) {
            double previousMinDistance = previousDxs.getDouble(i);
            if (clusterSet.isInverse()) {
                if (dxs.getDouble(i) < previousMinDistance) {
                    dxs.putScalar(i, previousMinDistance);
                }
            } else if (dxs.getDouble(i) > previousMinDistance)
                dxs.putScalar(i, previousMinDistance);
        }

        return dxs;
    }
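
    // dxs maintenance above: only the distance to the *newest* center is
    // computed each round; each entry is then clamped against the previous
    // minimum, so dxs always holds the distance to the nearest center seen so
    // far (the maximum instead, in the inverse/similarity case).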

    public static INDArray computeWeightedProbaDistancesFromNearestCluster(final ClusterSet clusterSet,
                    final List<Point> points, INDArray previousDxs) {
        final int pointsCount = points.size();
        final INDArray dxs = Nd4j.create(pointsCount);
        final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1);

        double sum = 0.0;
        for (int i = 0; i < pointsCount; i++) {

            Point point = points.get(i);
            double dist = Math.pow(newCluster.getDistanceToCenter(point), 2);
            sum += dist;
            dxs.putScalar(i, sum);
        }

        return dxs;
    }
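
    // The array filled above is a running (cumulative) sum of squared
    // distances, i.e. an unnormalised CDF: drawing u uniformly from [0, total)
    // and taking the first index with dxs[i] >= u samples a point with
    // probability proportional to D(x)^2, as k-means++ requires.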

    /**
     * Compute the cluster set info with a dedicated executor.
     *
     * @param clusterSet the cluster set to describe
     * @return the computed info
     */
    public static ClusterSetInfo computeClusterSetInfo(ClusterSet clusterSet) {
        ExecutorService executor = MultiThreadUtils.newExecutorService();
        ClusterSetInfo info = computeClusterSetInfo(clusterSet, executor);
        executor.shutdownNow();
        return info;
    }

    public static ClusterSetInfo computeClusterSetInfo(final ClusterSet clusterSet, ExecutorService executorService) {
        final ClusterSetInfo info = new ClusterSetInfo(clusterSet.isInverse(), true);
        int clusterCount = clusterSet.getClusterCount();

        List<Runnable> tasks = new ArrayList<>();
        for (int i = 0; i < clusterCount; i++) {
            final Cluster cluster = clusterSet.getClusters().get(i);
            //tasks.add(new Runnable() {
            //    public void run() {
            try {
                info.getClustersInfos().put(cluster.getId(),
                                computeClusterInfos(cluster, clusterSet.getDistanceFunction()));
            } catch (Throwable t) {
                log.warn("Error computing cluster set info", t);
            }
            //    }
            //});
        }

        //MultiThreadUtils.parallelTasks(tasks, executorService);

        //tasks = new ArrayList<>();
        for (int i = 0; i < clusterCount; i++) {
            final int clusterIdx = i;
            final Cluster fromCluster = clusterSet.getClusters().get(i);
            //tasks.add(new Runnable() {
            //    public void run() {
            try {
                for (int k = clusterIdx + 1, l = clusterSet.getClusterCount(); k < l; k++) {
                    Cluster toCluster = clusterSet.getClusters().get(k);
                    double distance = Nd4j.getExecutioner()
                                    .execAndReturn(ClusterUtils.createDistanceFunctionOp(
                                                    clusterSet.getDistanceFunction(),
                                                    fromCluster.getCenter().getArray(),
                                                    toCluster.getCenter().getArray()))
                                    .getFinalResult().doubleValue();
                    info.getDistancesBetweenClustersCenters().put(fromCluster.getId(), toCluster.getId(),
                                    distance);
                }
            } catch (Throwable t) {
                log.warn("Error computing distances", t);
            }
            //    }
            //});

        }

        //MultiThreadUtils.parallelTasks(tasks, executorService);

        return info;
    }

    /**
     * Compute the per-point distance info for a single cluster.
     *
     * @param cluster          the cluster to describe
     * @param distanceFunction the distance function to use
     * @return the computed info
     */
    public static ClusterInfo computeClusterInfos(Cluster cluster, Distance distanceFunction) {
        ClusterInfo info = new ClusterInfo(cluster.isInverse(), true);
        for (int i = 0, j = cluster.getPoints().size(); i < j; i++) {
            Point point = cluster.getPoints().get(i);
            //shouldn't need to inverse here. other parts of
            //the code should interpret the "distance" or score here
            double distance = Nd4j.getExecutioner()
                            .execAndReturn(ClusterUtils.createDistanceFunctionOp(distanceFunction,
                                            cluster.getCenter().getArray(), point.getArray()))
                            .getFinalResult().doubleValue();
            info.getPointDistancesFromCenter().put(point.getId(), distance);
            double diff = info.getTotalPointDistanceFromCenter() + distance;
            info.setTotalPointDistanceFromCenter(diff);
        }

        if (!cluster.getPoints().isEmpty())
            info.setAveragePointDistanceFromCenter(info.getTotalPointDistanceFromCenter() / cluster.getPoints().size());
        return info;
    }

    /**
     * Apply the given optimisation strategy, splitting clusters whose average or
     * maximum point-to-center distance exceeds the configured threshold.
     *
     * @param optimization   the optimisation strategy to apply
     * @param clusterSet     the current cluster set
     * @param clusterSetInfo the current cluster set statistics
     * @param executor       the executor used for the split tasks
     * @return true if any cluster was split
     */
    public static boolean applyOptimization(OptimisationStrategy optimization, ClusterSet clusterSet,
                    ClusterSetInfo clusterSetInfo, ExecutorService executor) {

        if (optimization.isClusteringOptimizationType(
                        ClusteringOptimizationType.MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE)) {
            int splitCount = ClusterUtils.splitClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet,
                            clusterSetInfo, optimization.getClusteringOptimizationValue(), executor);
            return splitCount > 0;
        }

        if (optimization.isClusteringOptimizationType(
                        ClusteringOptimizationType.MINIMIZE_MAXIMUM_POINT_TO_CENTER_DISTANCE)) {
            int splitCount = ClusterUtils.splitClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet,
                            clusterSetInfo, optimization.getClusteringOptimizationValue(), executor);
            return splitCount > 0;
        }

        return false;
    }

    /**
     * @param clusterSet the current cluster set
     * @param info       the current cluster set statistics
     * @param count      the number of clusters to return
     * @return the {@code count} clusters with the largest total point-to-center distance
     */
    public static List<Cluster> getMostSpreadOutClusters(final ClusterSet clusterSet, final ClusterSetInfo info,
                    int count) {
        List<Cluster> clusters = new ArrayList<>(clusterSet.getClusters());
        Collections.sort(clusters, new Comparator<Cluster>() {
            public int compare(Cluster o1, Cluster o2) {
                Double o1TotalDistance = info.getClusterInfo(o1.getId()).getTotalPointDistanceFromCenter();
                Double o2TotalDistance = info.getClusterInfo(o2.getId()).getTotalPointDistanceFromCenter();
                int comp = o1TotalDistance.compareTo(o2TotalDistance);
                return !clusterSet.getClusters().get(0).isInverse() ? -comp : comp;
            }
        });

        return clusters.subList(0, count);
    }

    /**
     * @param clusterSet             the current cluster set
     * @param info                   the current cluster set statistics
     * @param maximumAverageDistance the average-distance threshold
     * @return the clusters whose average point-to-center distance exceeds the threshold
     */
    public static List<Cluster> getClustersWhereAverageDistanceFromCenterGreaterThan(final ClusterSet clusterSet,
                    final ClusterSetInfo info, double maximumAverageDistance) {
        List<Cluster> clusters = new ArrayList<>();
        for (Cluster cluster : clusterSet.getClusters()) {
            ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId());
            if (clusterInfo != null) {
                //distances
                if (clusterInfo.isInverse()) {
                    if (clusterInfo.getAveragePointDistanceFromCenter() < maximumAverageDistance)
                        clusters.add(cluster);
                } else {
                    if (clusterInfo.getAveragePointDistanceFromCenter() > maximumAverageDistance)
                        clusters.add(cluster);
                }
            }
        }
        return clusters;
    }

    /**
     * @param clusterSet      the current cluster set
     * @param info            the current cluster set statistics
     * @param maximumDistance the maximum-distance threshold
     * @return the clusters whose maximum point-to-center distance exceeds the threshold
     */
    public static List<Cluster> getClustersWhereMaximumDistanceFromCenterGreaterThan(final ClusterSet clusterSet,
                    final ClusterSetInfo info, double maximumDistance) {
        List<Cluster> clusters = new ArrayList<>();
        for (Cluster cluster : clusterSet.getClusters()) {
            ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId());
            if (clusterInfo != null) {
                if (clusterInfo.isInverse() && clusterInfo.getMaxPointDistanceFromCenter() < maximumDistance) {
                    clusters.add(cluster);
                } else if (clusterInfo.getMaxPointDistanceFromCenter() > maximumDistance) {
                    clusters.add(cluster);
                }
            }
        }
        return clusters;
    }

    /**
     * Split the {@code count} most spread-out clusters.
     *
     * @return the number of clusters split
     */
    public static int splitMostSpreadOutClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count,
                    ExecutorService executorService) {
        List<Cluster> clustersToSplit = getMostSpreadOutClusters(clusterSet, clusterSetInfo, count);
        splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService);
        return clustersToSplit.size();
    }

    /**
     * Split every cluster whose average point-to-center distance exceeds the threshold.
     *
     * @return the number of clusters split
     */
    public static int splitClustersWhereAverageDistanceFromCenterGreaterThan(ClusterSet clusterSet,
                    ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) {
        List<Cluster> clustersToSplit = getClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo,
                        maxWithinClusterDistance);
        splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService);
        return clustersToSplit.size();
    }

    /**
     * Split every cluster whose maximum point-to-center distance exceeds the threshold.
     *
     * @return the number of clusters split
     */
    public static int splitClustersWhereMaximumDistanceFromCenterGreaterThan(ClusterSet clusterSet,
                    ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) {
        List<Cluster> clustersToSplit = getClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo,
                        maxWithinClusterDistance);
        splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService);
        return clustersToSplit.size();
    }

    /**
     * Split the {@code count} most populated clusters.
     */
    public static void splitMostPopulatedClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count,
                    ExecutorService executorService) {
        List<Cluster> clustersToSplit = clusterSet.getMostPopulatedClusters(count);
        splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService);
    }

    /**
     * Split each given cluster by promoting one of its far-from-center points
     * (one of the up-to-three farthest beyond {@code maxDistance}) to a new cluster center.
     */
    public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo,
                    List<Cluster> clusters, final double maxDistance, ExecutorService executorService) {
        final Random random = new Random();
        List<Runnable> tasks = new ArrayList<>();
        for (final Cluster cluster : clusters) {
            tasks.add(new Runnable() {
                public void run() {
                    try {
                        ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
                        List<String> fartherPoints = clusterInfo.getPointsFartherFromCenterThan(maxDistance);
                        int rank = Math.min(fartherPoints.size(), 3);
                        String pointId = fartherPoints.get(random.nextInt(rank));
                        Point point = cluster.removePoint(pointId);
                        clusterSet.addNewClusterWithCenter(point);
                    } catch (Throwable t) {
                        log.warn("Error splitting clusters", t);
                    }
                }
            });
        }
        MultiThreadUtils.parallelTasks(tasks, executorService);
    }

    /**
     * Split each given cluster by promoting one of its points, chosen at random,
     * to a new cluster center.
     */
    public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo,
                    List<Cluster> clusters, ExecutorService executorService) {
        final Random random = new Random();
        List<Runnable> tasks = new ArrayList<>();
        for (final Cluster cluster : clusters) {
            tasks.add(new Runnable() {
                public void run() {
                    try {
                        Point point = cluster.getPoints().remove(random.nextInt(cluster.getPoints().size()));
                        clusterSet.addNewClusterWithCenter(point);
                    } catch (Throwable t) {
                        log.warn("Error splitting clusters (2)", t);
                    }
                }
            });
        }

        MultiThreadUtils.parallelTasks(tasks, executorService);
    }

    public static ReduceOp createDistanceFunctionOp(Distance distanceFunction, INDArray x, INDArray y, int... dimensions) {
        val op = createDistanceFunctionOp(distanceFunction, x, y);
        op.setDimensions(dimensions);
        return op;
    }

    public static ReduceOp createDistanceFunctionOp(Distance distanceFunction, INDArray x, INDArray y) {
        switch (distanceFunction) {
            case COSINE_DISTANCE:
                return new CosineDistance(x, y);
            case COSINE_SIMILARITY:
                return new CosineSimilarity(x, y);
            case DOT:
                return new Dot(x, y);
            case EUCLIDEAN:
                return new EuclideanDistance(x, y);
            case JACCARD:
                return new JaccardDistance(x, y);
            case MANHATTAN:
                return new ManhattanDistance(x, y);
            case HAMMING:
                //Added so every Distance constant is handled rather than falling
                //through to "Unknown distance function"; HammingDistance lives in
                //the same reduce3 package as the ops above.
                return new HammingDistance(x, y);
            default:
                throw new IllegalStateException("Unknown distance function: " + distanceFunction);
        }
    }
|
|
||||||
}
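The factory above maps the Distance enum onto ND4J reduce ops, which are then run through the op executioner. A minimal usage sketch, assuming the imports already shown in this class and calling the static method directly (the values are illustrative, not from the original sources):

    INDArray x = Nd4j.createFromArray(new float[] {0f, 0f});
    INDArray y = Nd4j.createFromArray(new float[] {3f, 4f});
    ReduceOp op = createDistanceFunctionOp(Distance.EUCLIDEAN, x, y);
    // Executes the reduction and reads back the scalar result: 5.0 for this 3-4-5 triangle
    double d = Nd4j.getExecutioner().execAndReturn(op).getFinalResult().doubleValue();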
@ -1,107 +0,0 @@
package org.deeplearning4j.clustering.cluster;

import lombok.AccessLevel;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;

/**
 * A single point (vector) to be clustered, identified by a unique id and an optional label.
 */
@Data
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public class Point implements Serializable {

    private static final long serialVersionUID = -6658028541426027226L;

    private String id = UUID.randomUUID().toString();
    private String label;
    private INDArray array;

    /**
     * @param array the vector backing this point
     */
    public Point(INDArray array) {
        super();
        this.array = array;
    }

    /**
     * @param id the unique id of this point
     * @param array the vector backing this point
     */
    public Point(String id, INDArray array) {
        super();
        this.id = id;
        this.array = array;
    }

    public Point(String id, String label, double[] data) {
        this(id, label, Nd4j.create(data));
    }

    public Point(String id, String label, INDArray array) {
        super();
        this.id = id;
        this.label = label;
        this.array = array;
    }

    /**
     * Wrap each row of the given matrix in a {@link Point}.
     *
     * @param matrix a matrix with one point per row
     * @return one point per row of the matrix
     */
    public static List<Point> toPoints(INDArray matrix) {
        List<Point> arr = new ArrayList<>(matrix.rows());
        for (int i = 0; i < matrix.rows(); i++) {
            arr.add(new Point(matrix.getRow(i)));
        }

        return arr;
    }

    /**
     * Wrap each of the given vectors in a {@link Point}.
     *
     * @param vectors the vectors to wrap
     * @return one point per vector
     */
    public static List<Point> toPoints(List<INDArray> vectors) {
        List<Point> points = new ArrayList<>();
        for (INDArray vector : vectors)
            points.add(new Point(vector));
        return points;
    }

}
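The toPoints helpers are the usual entry point for feeding raw ND4J data into the clustering API. A minimal sketch, using only the API shown above:

    INDArray data = Nd4j.rand(5, 3);            // five 3-dimensional points, one per row
    List<Point> points = Point.toPoints(data);  // auto-generated UUID ids, no labels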
@ -1,40 +0,0 @@
package org.deeplearning4j.clustering.cluster;

import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.io.Serializable;

@Data
@NoArgsConstructor(access = AccessLevel.PROTECTED)
@AllArgsConstructor
public class PointClassification implements Serializable {

    private Cluster cluster;
    private double distanceFromCenter;
    private boolean newLocation;

}
@ -1,37 +0,0 @@
package org.deeplearning4j.clustering.condition;

import org.deeplearning4j.clustering.iteration.IterationHistory;

/**
 * A termination condition for an iterative clustering algorithm.
 */
public interface ClusteringAlgorithmCondition {

    /**
     * @param iterationHistory the history of the iterations run so far
     * @return true if the algorithm should stop iterating
     */
    boolean isSatisfied(IterationHistory iterationHistory);

}
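Since the interface has a single abstract method, ad-hoc conditions can be written as lambdas. A hypothetical example, not part of the original sources:

    // Stop unconditionally after ten iterations
    ClusteringAlgorithmCondition tenIterations =
                    iterationHistory -> iterationHistory.getIterationCount() >= 10;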
@ -1,69 +0,0 @@
package org.deeplearning4j.clustering.condition;

import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor;
import org.deeplearning4j.clustering.iteration.IterationHistory;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.LessThan;

import java.io.Serializable;

/**
 * A condition satisfied once the rate at which points change cluster drops below a threshold.
 */
@NoArgsConstructor(access = AccessLevel.PROTECTED)
@AllArgsConstructor(access = AccessLevel.PROTECTED)
public class ConvergenceCondition implements ClusteringAlgorithmCondition, Serializable {

    private Condition convergenceCondition;
    private double pointsDistributionChangeRate;

    /**
     * @param pointsDistributionChangeRate the fraction of points allowed to change cluster per iteration
     * @return a condition satisfied once fewer points than that fraction move between clusters
     */
    public static ConvergenceCondition distributionVariationRateLessThan(double pointsDistributionChangeRate) {
        Condition condition = new LessThan(pointsDistributionChangeRate);
        return new ConvergenceCondition(condition, pointsDistributionChangeRate);
    }

    /**
     * @param iterationHistory the history of the iterations run so far
     * @return true if the point distribution has stabilized below the configured rate
     */
    public boolean isSatisfied(IterationHistory iterationHistory) {
        int iterationCount = iterationHistory.getIterationCount();
        if (iterationCount <= 1)
            return false;

        // Fraction of points that changed cluster during the most recent iteration
        double variation = iterationHistory.getMostRecentClusterSetInfo().getPointLocationChange().get();
        variation /= iterationHistory.getMostRecentClusterSetInfo().getPointsCount();

        return convergenceCondition.apply(variation);
    }

}
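The condition compares the per-iteration relocation rate, variation = (points that changed cluster) / (total points), against the threshold. A typical configuration, with an illustrative value:

    // Converged once fewer than 1% of points change cluster in an iteration
    ClusteringAlgorithmCondition converged = ConvergenceCondition.distributionVariationRateLessThan(0.01);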
@ -1,61 +0,0 @@
package org.deeplearning4j.clustering.condition;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.deeplearning4j.clustering.iteration.IterationHistory;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.GreaterThanOrEqual;

import java.io.Serializable;

/**
 * A condition satisfied once a fixed number of iterations has been run.
 */
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public class FixedIterationCountCondition implements ClusteringAlgorithmCondition, Serializable {

    private Condition iterationCountCondition;

    protected FixedIterationCountCondition(int iterationCount) {
        iterationCountCondition = new GreaterThanOrEqual(iterationCount);
    }

    /**
     * @param iterationCount the iteration count at which to stop
     * @return a condition satisfied once at least that many iterations have run
     */
    public static FixedIterationCountCondition iterationCountGreaterThan(int iterationCount) {
        return new FixedIterationCountCondition(iterationCount);
    }

    /**
     * @param iterationHistory the history of the iterations run so far
     * @return true once the configured iteration count has been reached
     */
    public boolean isSatisfied(IterationHistory iterationHistory) {
        return iterationCountCondition.apply(iterationHistory == null ? 0 : iterationHistory.getIterationCount());
    }

}
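Usage is a one-liner (the value is illustrative):

    // Stop after 100 iterations regardless of convergence
    ClusteringAlgorithmCondition cap = FixedIterationCountCondition.iterationCountGreaterThan(100);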
@ -1,82 +0,0 @@
package org.deeplearning4j.clustering.condition;

import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor;
import org.deeplearning4j.clustering.iteration.IterationHistory;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.LessThan;

import java.io.Serializable;

/**
 * A condition satisfied when the variance of point-to-center distances has stopped changing
 * (in relative terms) over a window of recent iterations.
 */
@NoArgsConstructor(access = AccessLevel.PROTECTED)
@AllArgsConstructor
public class VarianceVariationCondition implements ClusteringAlgorithmCondition, Serializable {

    private Condition varianceVariationCondition;
    private int period;

    /**
     * @param varianceVariation the maximum allowed relative change in variance per iteration
     * @param period the number of consecutive iterations over which the bound must hold
     * @return the corresponding condition
     */
    public static VarianceVariationCondition varianceVariationLessThan(double varianceVariation, int period) {
        Condition condition = new LessThan(varianceVariation);
        return new VarianceVariationCondition(condition, period);
    }

    /**
     * @param iterationHistory the history of the iterations run so far
     * @return true if the relative variance change stayed below the bound for the whole period
     */
    public boolean isSatisfied(IterationHistory iterationHistory) {
        if (iterationHistory.getIterationCount() <= period)
            return false;

        for (int i = 0, j = iterationHistory.getIterationCount(); i < period; i++) {
            // Relative change of the distance variance between iterations (j - i - 1) and (j - i)
            double variation = iterationHistory.getIterationInfo(j - i).getClusterSetInfo()
                            .getPointDistanceFromClusterVariance();
            variation -= iterationHistory.getIterationInfo(j - i - 1).getClusterSetInfo()
                            .getPointDistanceFromClusterVariance();
            variation /= iterationHistory.getIterationInfo(j - i - 1).getClusterSetInfo()
                            .getPointDistanceFromClusterVariance();

            if (!varianceVariationCondition.apply(variation))
                return false;
        }

        return true;
    }

}
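Each step of the window checks the relative first difference of the distance variance v, that is (v_i - v_{i-1}) / v_{i-1}. A typical configuration, with illustrative values:

    // Satisfied once the variance changes by less than 0.1% over three consecutive iterations
    ClusteringAlgorithmCondition stable = VarianceVariationCondition.varianceVariationLessThan(0.001, 3);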
@ -1,114 +0,0 @@
package org.deeplearning4j.clustering.info;

import lombok.Data;

import java.io.Serializable;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Per-cluster statistics: the distances of a cluster's points from its center.
 */
@Data
public class ClusterInfo implements Serializable {

    private double averagePointDistanceFromCenter;
    private double maxPointDistanceFromCenter;
    private double pointDistanceFromCenterVariance;
    private double totalPointDistanceFromCenter;
    private boolean inverse;
    private Map<String, Double> pointDistancesFromCenter = new ConcurrentHashMap<>();

    public ClusterInfo(boolean inverse) {
        this(false, inverse);
    }

    /**
     * @param threadSafe whether to wrap the distance map in a synchronized view
     * @param inverse whether the distance function is inverse (a similarity: higher means closer)
     */
    public ClusterInfo(boolean threadSafe, boolean inverse) {
        super();
        this.inverse = inverse;
        if (threadSafe) {
            pointDistancesFromCenter = Collections.synchronizedMap(pointDistancesFromCenter);
        }
    }

    /**
     * @return the point-to-center distances, sorted in ascending order
     */
    public Set<Map.Entry<String, Double>> getSortedPointDistancesFromCenter() {
        SortedSet<Map.Entry<String, Double>> sortedEntries = new TreeSet<>(new Comparator<Map.Entry<String, Double>>() {
            @Override
            public int compare(Map.Entry<String, Double> e1, Map.Entry<String, Double> e2) {
                int res = e1.getValue().compareTo(e2.getValue());
                // Never return 0, so the TreeSet keeps entries with equal distances
                return res != 0 ? res : 1;
            }
        });
        sortedEntries.addAll(pointDistancesFromCenter.entrySet());
        return sortedEntries;
    }

    /**
     * @return the point-to-center distances, sorted in descending order
     */
    public Set<Map.Entry<String, Double>> getReverseSortedPointDistancesFromCenter() {
        SortedSet<Map.Entry<String, Double>> sortedEntries = new TreeSet<>(new Comparator<Map.Entry<String, Double>>() {
            @Override
            public int compare(Map.Entry<String, Double> e1, Map.Entry<String, Double> e2) {
                int res = e1.getValue().compareTo(e2.getValue());
                return -(res != 0 ? res : 1);
            }
        });
        sortedEntries.addAll(pointDistancesFromCenter.entrySet());
        return sortedEntries;
    }

    /**
     * @param maxDistance the distance threshold
     * @return the ids of all points farther from the center than {@code maxDistance}
     */
    public List<String> getPointsFartherFromCenterThan(double maxDistance) {
        Set<Map.Entry<String, Double>> sorted = getReverseSortedPointDistancesFromCenter();
        List<String> ids = new ArrayList<>();
        for (Map.Entry<String, Double> entry : sorted) {
            if (inverse && entry.getValue() < -maxDistance)
                break;
            else if (entry.getValue() > maxDistance)
                break;

            ids.add(entry.getKey());
        }
        return ids;
    }

}
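These listings feed the splitting heuristics shown earlier: the reverse-sorted view puts the worst outliers first. A minimal sketch, reusing the call pattern from the split code above (the threshold is illustrative):

    ClusterInfo info = clusterSetInfo.getClusterInfo(cluster.getId());
    List<String> outlierIds = info.getPointsFartherFromCenterThan(2.5);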
@ -1,142 +0,0 @@
package org.deeplearning4j.clustering.info;

import org.nd4j.shade.guava.collect.HashBasedTable;
import org.nd4j.shade.guava.collect.Table;
import org.deeplearning4j.clustering.cluster.Cluster;
import org.deeplearning4j.clustering.cluster.ClusterSet;

import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

public class ClusterSetInfo implements Serializable {

    private Map<String, ClusterInfo> clustersInfos = new HashMap<>();
    private Table<String, String, Double> distancesBetweenClustersCenters = HashBasedTable.create();
    private AtomicInteger pointLocationChange;
    private boolean threadSafe;
    private boolean inverse;

    public ClusterSetInfo(boolean inverse) {
        this(inverse, false);
    }

    /**
     * @param inverse whether the distance function is inverse (a similarity: higher means closer)
     * @param threadSafe whether to wrap the per-cluster info map in a synchronized view
     */
    public ClusterSetInfo(boolean inverse, boolean threadSafe) {
        this.pointLocationChange = new AtomicInteger(0);
        this.threadSafe = threadSafe;
        this.inverse = inverse;
        if (threadSafe) {
            clustersInfos = Collections.synchronizedMap(clustersInfos);
        }
    }

    /**
     * Create a ClusterSetInfo with one (empty) ClusterInfo per cluster in the set.
     *
     * @param clusterSet the cluster set to describe
     * @param threadSafe whether the info objects must be safe for concurrent updates
     * @return the initialized info object
     */
    public static ClusterSetInfo initialize(ClusterSet clusterSet, boolean threadSafe) {
        ClusterSetInfo info = new ClusterSetInfo(clusterSet.isInverse(), threadSafe);
        for (int i = 0, j = clusterSet.getClusterCount(); i < j; i++)
            info.addClusterInfo(clusterSet.getClusters().get(i).getId());
        return info;
    }

    public void removeClusterInfos(List<Cluster> clusters) {
        for (Cluster cluster : clusters) {
            clustersInfos.remove(cluster.getId());
        }
    }

    public ClusterInfo addClusterInfo(String clusterId) {
        // Pass both flags explicitly so the inverse setting propagates to the per-cluster info
        ClusterInfo clusterInfo = new ClusterInfo(this.threadSafe, this.inverse);
        clustersInfos.put(clusterId, clusterInfo);
        return clusterInfo;
    }

    public ClusterInfo getClusterInfo(String clusterId) {
        return clustersInfos.get(clusterId);
    }

    public double getAveragePointDistanceFromClusterCenter() {
        if (clustersInfos == null || clustersInfos.isEmpty())
            return 0;

        double average = 0;
        for (ClusterInfo info : clustersInfos.values())
            average += info.getAveragePointDistanceFromCenter();
        return average / clustersInfos.size();
    }

    public double getPointDistanceFromClusterVariance() {
        if (clustersInfos == null || clustersInfos.isEmpty())
            return 0;

        double average = 0;
        for (ClusterInfo info : clustersInfos.values())
            average += info.getPointDistanceFromCenterVariance();
        return average / clustersInfos.size();
    }

    public int getPointsCount() {
        int count = 0;
        for (ClusterInfo clusterInfo : clustersInfos.values())
            count += clusterInfo.getPointDistancesFromCenter().size();
        return count;
    }

    public Map<String, ClusterInfo> getClustersInfos() {
        return clustersInfos;
    }

    public void setClustersInfos(Map<String, ClusterInfo> clustersInfos) {
        this.clustersInfos = clustersInfos;
    }

    public Table<String, String, Double> getDistancesBetweenClustersCenters() {
        return distancesBetweenClustersCenters;
    }

    public void setDistancesBetweenClustersCenters(Table<String, String, Double> interClusterDistances) {
        this.distancesBetweenClustersCenters = interClusterDistances;
    }

    public AtomicInteger getPointLocationChange() {
        return pointLocationChange;
    }

    public void setPointLocationChange(AtomicInteger pointLocationChange) {
        this.pointLocationChange = pointLocationChange;
    }

}
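A typical lifecycle: build the info once per iteration from the current cluster set, then query the aggregates that the termination conditions consume. A sketch using only the API above:

    ClusterSetInfo info = ClusterSetInfo.initialize(clusterSet, true);
    double avg = info.getAveragePointDistanceFromClusterCenter();
    double var = info.getPointDistanceFromClusterVariance();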
@ -1,72 +0,0 @@
package org.deeplearning4j.clustering.iteration;

import lombok.Getter;
import lombok.Setter;
import org.deeplearning4j.clustering.info.ClusterSetInfo;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;

public class IterationHistory implements Serializable {

    @Getter
    @Setter
    private Map<Integer, IterationInfo> iterationsInfos = new HashMap<>();

    /**
     * @return the cluster set statistics of the most recent iteration, or null if none has run
     */
    public ClusterSetInfo getMostRecentClusterSetInfo() {
        IterationInfo iterationInfo = getMostRecentIterationInfo();
        return iterationInfo == null ? null : iterationInfo.getClusterSetInfo();
    }

    /**
     * @return the info of the most recent iteration, or null if none has run
     */
    public IterationInfo getMostRecentIterationInfo() {
        return getIterationInfo(getIterationCount() - 1);
    }

    /**
     * @return the number of iterations recorded so far
     */
    public int getIterationCount() {
        return getIterationsInfos().size();
    }

    /**
     * @param iterationIdx the zero-based iteration index
     * @return the info recorded for that iteration, or null if absent
     */
    public IterationInfo getIterationInfo(int iterationIdx) {
        return getIterationsInfos().get(iterationIdx);
    }

}
@ -1,49 +0,0 @@
package org.deeplearning4j.clustering.iteration;

import lombok.AccessLevel;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.deeplearning4j.clustering.info.ClusterSetInfo;

import java.io.Serializable;

@Data
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public class IterationInfo implements Serializable {

    private int index;
    private ClusterSetInfo clusterSetInfo;
    private boolean strategyApplied;

    public IterationInfo(int index) {
        super();
        this.index = index;
    }

    public IterationInfo(int index, ClusterSetInfo clusterSetInfo) {
        super();
        this.index = index;
        this.clusterSetInfo = clusterSetInfo;
    }

}
@ -1,142 +0,0 @@
package org.deeplearning4j.clustering.kdtree;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.custom.KnnMinDistance;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;

import java.io.Serializable;

/**
 * An axis-aligned bounding box, used to prune subtrees during k-d tree searches.
 */
public class HyperRect implements Serializable {

    //private List<Interval> points;
    private float[] lowerEnds;
    private float[] higherEnds;
    private INDArray lowerEndsIND;
    private INDArray higherEndsIND;

    public HyperRect(float[] lowerEndsIn, float[] higherEndsIn) {
        this.lowerEnds = new float[lowerEndsIn.length];
        this.higherEnds = new float[lowerEndsIn.length];
        System.arraycopy(lowerEndsIn, 0, this.lowerEnds, 0, lowerEndsIn.length);
        System.arraycopy(higherEndsIn, 0, this.higherEnds, 0, higherEndsIn.length);
        lowerEndsIND = Nd4j.createFromArray(lowerEnds);
        higherEndsIND = Nd4j.createFromArray(higherEnds);
    }

    /** A degenerate rectangle containing a single point. */
    public HyperRect(float[] point) {
        this(point, point);
    }

    public HyperRect(Pair<float[], float[]> ends) {
        this(ends.getFirst(), ends.getSecond());
    }

    /** Grow this rectangle (in place) so that it contains the given point. */
    public void enlargeTo(INDArray point) {
        float[] pointAsArray = point.toFloatVector();
        for (int i = 0; i < lowerEnds.length; i++) {
            float p = pointAsArray[i];
            if (lowerEnds[i] > p)
                lowerEnds[i] = p;
            else if (higherEnds[i] < p)
                higherEnds[i] = p;
        }
    }

    public static Pair<float[], float[]> point(INDArray vector) {
        Pair<float[], float[]> ret = new Pair<>();
        float[] curr = new float[(int) vector.length()];
        for (int i = 0; i < vector.length(); i++) {
            curr[i] = vector.getFloat(i);
        }
        ret.setFirst(curr);
        ret.setSecond(curr);
        return ret;
    }

    /*public List<Boolean> contains(INDArray hPoint) {
        List<Boolean> ret = new ArrayList<>();
        for (int i = 0; i < hPoint.length(); i++) {
            ret.add(lowerEnds[i] <= hPoint.getDouble(i) &&
                    higherEnds[i] >= hPoint.getDouble(i));
        }
        return ret;
    }*/

    /** Minimum distance from the given point to this rectangle (0 if the point lies inside). */
    public double minDistance(INDArray hPoint, INDArray output) {
        Nd4j.exec(new KnnMinDistance(hPoint, lowerEndsIND, higherEndsIND, output));
        return output.getFloat(0);

        /*double ret = 0.0;
        double[] pointAsArray = hPoint.toDoubleVector();
        for (int i = 0; i < pointAsArray.length; i++) {
            double p = pointAsArray[i];
            if (!(lowerEnds[i] <= p || higherEnds[i] <= p)) {
                if (p < lowerEnds[i])
                    ret += Math.pow((p - lowerEnds[i]), 2);
                else
                    ret += Math.pow((p - higherEnds[i]), 2);
            }
        }
        ret = Math.pow(ret, 0.5);
        return ret;*/
    }

    /** The part of this rectangle above the given point's coordinate along dimension {@code desc}. */
    public HyperRect getUpper(INDArray hPoint, int desc) {
        //Interval interval = points.get(desc);
        float higher = higherEnds[desc];
        float d = hPoint.getFloat(desc);
        if (higher < d)
            return null;
        HyperRect ret = new HyperRect(lowerEnds, higherEnds);
        if (ret.lowerEnds[desc] < d)
            ret.lowerEnds[desc] = d;
        return ret;
    }

    /** The part of this rectangle below the given point's coordinate along dimension {@code desc}. */
    public HyperRect getLower(INDArray hPoint, int desc) {
        //Interval interval = points.get(desc);
        float lower = lowerEnds[desc];
        float d = hPoint.getFloat(desc);
        if (lower > d)
            return null;
        HyperRect ret = new HyperRect(lowerEnds, higherEnds);
        //Interval i2 = ret.points.get(desc);
        if (ret.higherEnds[desc] > d)
            ret.higherEnds[desc] = d;
        return ret;
    }

    @Override
    public String toString() {
        StringBuilder retVal = new StringBuilder("[");
        for (int i = 0; i < lowerEnds.length; ++i) {
            retVal.append("(").append(lowerEnds[i]).append(" - ").append(higherEnds[i]).append(") ");
        }
        retVal.append("]");
        return retVal.toString();
    }
}
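These rectangle operations drive the pruning in the KDTree below: getUpper and getLower split the search region at a node's splitting coordinate, and minDistance lower-bounds the distance from a query to anything inside a region. A small sketch of enlargeTo, with illustrative values:

    HyperRect rect = new HyperRect(new float[] {0f, 0f});        // degenerate box at the origin
    rect.enlargeTo(Nd4j.createFromArray(new float[] {2f, 3f}));  // box now spans [0,2] x [0,3]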
@ -1,370 +0,0 @@
package org.deeplearning4j.clustering.kdtree;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.bool.Any;
import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

public class KDTree implements Serializable {

    private KDNode root;
    private int dims = 100;
    public final static int GREATER = 1;
    public final static int LESS = 0;
    private int size = 0;
    private HyperRect rect;

    public KDTree(int dims) {
        this.dims = dims;
    }

    /**
     * Insert a point into the tree.
     * @param point the point to insert
     */
    public void insert(INDArray point) {
        if (!point.isVector() || point.length() != dims)
            throw new IllegalArgumentException("Point must be a vector of length " + dims);

        if (root == null) {
            root = new KDNode(point);
            rect = new HyperRect(/*HyperRect.point(point)*/ point.toFloatVector());
        } else {
            int disc = 0;
            KDNode node = root;
            KDNode insert = new KDNode(point);
            int successor;
            while (true) {
                // A point exactly equal to an existing one is not inserted twice
                INDArray pt = node.getPoint();
                INDArray countEq = Nd4j.getExecutioner().execAndReturn(new Any(pt.neq(point))).z();
                if (countEq.getInt(0) == 0) {
                    return;
                } else {
                    successor = successor(node, point, disc);
                    KDNode child;
                    if (successor < 1)
                        child = node.getLeft();
                    else
                        child = node.getRight();
                    if (child == null)
                        break;
                    disc = (disc + 1) % dims;
                    node = child;
                }
            }

            if (successor < 1)
                node.setLeft(insert);
            else
                node.setRight(insert);

            rect.enlargeTo(point);
            insert.setParent(node);
        }
        size++;
    }

    public INDArray delete(INDArray point) {
        KDNode node = root;
        int _disc = 0;
        while (node != null) {
            if (node.point == point)
                break;
            int successor = successor(node, point, _disc);
            if (successor < 1)
                node = node.getLeft();
            else
                node = node.getRight();
            _disc = (_disc + 1) % dims;
        }

        // Point not present in the tree
        if (node == null)
            return null;

        INDArray deletedPoint = node.getPoint();
        if (node == root) {
            root = delete(root, _disc);
        } else
            node = delete(node, _disc);
        size--;
        if (size == 1) {
            rect = new HyperRect(HyperRect.point(point));
        } else if (size == 0)
            rect = null;

        return deletedPoint;
    }

    // Shared state for the recursive "knn" calls
    private float currentDistance;
    private INDArray currentPoint;
    private INDArray minDistance = Nd4j.scalar(0.f);

    public List<Pair<Float, INDArray>> knn(INDArray point, float distance) {
        List<Pair<Float, INDArray>> best = new ArrayList<>();
        currentDistance = distance;
        currentPoint = point;
        knn(root, rect, best, 0);
        Collections.sort(best, new Comparator<Pair<Float, INDArray>>() {
            @Override
            public int compare(Pair<Float, INDArray> o1, Pair<Float, INDArray> o2) {
                return Float.compare(o1.getKey(), o2.getKey());
            }
        });

        return best;
    }

    private void knn(KDNode node, HyperRect rect, List<Pair<Float, INDArray>> best, int _disc) {
        // Prune subtrees whose bounding box cannot contain a point within the search radius
        if (node == null || rect == null || rect.minDistance(currentPoint, minDistance) > currentDistance)
            return;
        int _discNext = (_disc + 1) % dims;
        float distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(currentPoint, node.point, minDistance)).getFinalResult()
                        .floatValue();

        if (distance <= currentDistance) {
            best.add(Pair.of(distance, node.getPoint()));
        }

        HyperRect lower = rect.getLower(node.point, _disc);
        HyperRect upper = rect.getUpper(node.point, _disc);
        knn(node.getLeft(), lower, best, _discNext);
        knn(node.getRight(), upper, best, _discNext);
    }

    /**
     * Query for the nearest neighbor. Returns the distance and point.
     * @param point the point to query for
     * @return the distance to, and the coordinates of, the nearest neighbor
     */
    public Pair<Double, INDArray> nn(INDArray point) {
        return nn(root, point, rect, Double.POSITIVE_INFINITY, null, 0);
    }

    private Pair<Double, INDArray> nn(KDNode node, INDArray point, HyperRect rect, double dist, INDArray best,
                    int _disc) {
        if (node == null || rect.minDistance(point, minDistance) > dist)
            return Pair.of(Double.POSITIVE_INFINITY, null);

        int _discNext = (_disc + 1) % dims;
        // Distance from the query to this node's point
        double dist2 = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(point, node.point)).getFinalResult().doubleValue();
        if (dist2 < dist) {
            best = node.getPoint();
            dist = dist2;
        }

        HyperRect lower = rect.getLower(node.point, _disc);
        HyperRect upper = rect.getUpper(node.point, _disc);

        if (point.getDouble(_disc) < node.point.getDouble(_disc)) {
            Pair<Double, INDArray> left = nn(node.getLeft(), point, lower, dist, best, _discNext);
            Pair<Double, INDArray> right = nn(node.getRight(), point, upper, dist, best, _discNext);
            if (left.getKey() < dist)
                return left;
            else if (right.getKey() < dist)
                return right;
        } else {
            Pair<Double, INDArray> left = nn(node.getRight(), point, upper, dist, best, _discNext);
            Pair<Double, INDArray> right = nn(node.getLeft(), point, lower, dist, best, _discNext);
            if (left.getKey() < dist)
                return left;
            else if (right.getKey() < dist)
                return right;
        }

        return Pair.of(dist, best);
    }

    private KDNode delete(KDNode delete, int _disc) {
        if (delete.getLeft() != null && delete.getRight() != null) {
            if (delete.getParent() != null) {
                if (delete.getParent().getLeft() == delete)
                    delete.getParent().setLeft(null);
                else
                    delete.getParent().setRight(null);
            }
            return null;
        }

        int disc = _disc;
        _disc = (_disc + 1) % dims;
        Pair<KDNode, Integer> qd = null;
        if (delete.getRight() != null) {
            qd = min(delete.getRight(), disc, _disc);
        } else if (delete.getLeft() != null)
            qd = max(delete.getLeft(), disc, _disc);
        if (qd == null) { // is leaf
            return null;
        }
        delete.point = qd.getKey().point;
        KDNode qFather = qd.getKey().getParent();
        if (qFather.getLeft() == qd.getKey()) {
            qFather.setLeft(delete(qd.getKey(), disc));
        } else if (qFather.getRight() == qd.getKey()) {
            qFather.setRight(delete(qd.getKey(), disc));
        }

        return delete;
    }

    private Pair<KDNode, Integer> max(KDNode node, int disc, int _disc) {
        int discNext = (_disc + 1) % dims;
        if (_disc == disc) {
            // The maximum along the split dimension lies in the right subtree
            KDNode child = node.getRight();
            if (child != null) {
                return max(child, disc, discNext);
            }
        } else if (node.getLeft() != null || node.getRight() != null) {
            Pair<KDNode, Integer> left = null, right = null;
            if (node.getLeft() != null)
                left = max(node.getLeft(), disc, discNext);
            if (node.getRight() != null)
                right = max(node.getRight(), disc, discNext);
            if (left != null && right != null) {
                double pointLeft = left.getKey().getPoint().getDouble(disc);
                double pointRight = right.getKey().getPoint().getDouble(disc);
                if (pointLeft > pointRight)
                    return left;
                else
                    return right;
            } else if (left != null)
                return left;
            else
                return right;
        }

        return Pair.of(node, _disc);
    }

    private Pair<KDNode, Integer> min(KDNode node, int disc, int _disc) {
        int discNext = (_disc + 1) % dims;
        if (_disc == disc) {
            // The minimum along the split dimension lies in the left subtree
            KDNode child = node.getLeft();
            if (child != null) {
                return min(child, disc, discNext);
            }
        } else if (node.getLeft() != null || node.getRight() != null) {
            Pair<KDNode, Integer> left = null, right = null;
            if (node.getLeft() != null)
                left = min(node.getLeft(), disc, discNext);
            if (node.getRight() != null)
                right = min(node.getRight(), disc, discNext);
            if (left != null && right != null) {
                double pointLeft = left.getKey().getPoint().getDouble(disc);
                double pointRight = right.getKey().getPoint().getDouble(disc);
                if (pointLeft < pointRight)
                    return left;
                else
                    return right;
            } else if (left != null)
                return left;
            else
                return right;
        }

        return Pair.of(node, _disc);
    }

    /**
     * The number of elements in the tree
     * @return the number of elements in the tree
     */
    public int size() {
        return size;
    }

    private int successor(KDNode node, INDArray point, int disc) {
        for (int i = disc; i < dims; i++) {
            double pointI = point.getDouble(i);
            double nodePointI = node.getPoint().getDouble(i);
            if (pointI < nodePointI)
                return LESS;
            else if (pointI > nodePointI)
                return GREATER;
        }

        throw new IllegalStateException("Point is equal!");
    }

    private static class KDNode {
        private INDArray point;
        private KDNode left, right, parent;

        public KDNode(INDArray point) {
            this.point = point;
        }

        public INDArray getPoint() {
            return point;
        }

        public KDNode getLeft() {
            return left;
        }

        public void setLeft(KDNode left) {
            this.left = left;
        }

        public KDNode getRight() {
            return right;
        }

        public void setRight(KDNode right) {
            this.right = right;
        }

        public KDNode getParent() {
            return parent;
        }

        public void setParent(KDNode parent) {
            this.parent = parent;
        }
    }

}
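End to end, a tree is populated point by point and then queried either by radius (knn) or for the single nearest neighbor (nn). A small sketch using only the API above, with illustrative values:

    KDTree tree = new KDTree(2);
    tree.insert(Nd4j.createFromArray(new float[] {0f, 0f}));
    tree.insert(Nd4j.createFromArray(new float[] {1f, 1f}));
    tree.insert(Nd4j.createFromArray(new float[] {5f, 5f}));

    // Nearest neighbor of (0.9, 0.8): the point (1, 1)
    Pair<Double, INDArray> nearest = tree.nn(Nd4j.createFromArray(new float[] {0.9f, 0.8f}));
    // All points within radius 2 of the origin: (0,0) and (1,1), sorted by distance
    List<Pair<Float, INDArray>> within2 = tree.knn(Nd4j.createFromArray(new float[] {0f, 0f}), 2f);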
@ -1,109 +0,0 @@
package org.deeplearning4j.clustering.kmeans;

import org.deeplearning4j.clustering.algorithm.BaseClusteringAlgorithm;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.strategy.ClusteringStrategy;
import org.deeplearning4j.clustering.strategy.FixedClusterCountStrategy;

public class KMeansClustering extends BaseClusteringAlgorithm {

    private static final long serialVersionUID = 8476951388145944776L;
    private static final double VARIATION_TOLERANCE = 1e-4;

    /**
     * @param clusteringStrategy the strategy (cluster count, distance function, termination conditions)
     * @param useKMeansPlusPlus whether to use k-means++ initialization
     */
    protected KMeansClustering(ClusteringStrategy clusteringStrategy, boolean useKMeansPlusPlus) {
        super(clusteringStrategy, useKMeansPlusPlus);
    }

    /**
     * Set up a k-means instance that stops after a fixed number of iterations.
     * @param clusterCount the number of clusters
     * @param maxIterationCount the max number of iterations to run k-means
     * @param distanceFunction the distance function to use for grouping
     * @param inverse whether the distance function is inverse (a similarity)
     * @param useKMeansPlusPlus whether to use k-means++ initialization
     * @return the configured instance
     */
    public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction,
                    boolean inverse, boolean useKMeansPlusPlus) {
        ClusteringStrategy clusteringStrategy =
                        FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse);
        clusteringStrategy.endWhenIterationCountEquals(maxIterationCount);
        return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
    }

    /**
     * Set up a k-means instance that stops once the point distribution stabilizes.
     * @param clusterCount the number of clusters
     * @param minDistributionVariationRate stop once fewer points than this rate change cluster per iteration
     * @param distanceFunction the distance function to use for grouping
     * @param inverse whether the distance function is inverse (a similarity)
     * @param allowEmptyClusters whether empty clusters are permitted
     * @param useKMeansPlusPlus whether to use k-means++ initialization
     * @return the configured instance
     */
    public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction,
                    boolean inverse, boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
        ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse)
                        .endWhenDistributionVariationRateLessThan(minDistributionVariationRate);
        return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
    }

    /**
     * Same as {@link #setup(int, int, Distance, boolean, boolean)} with a non-inverse distance function.
     */
    public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction, boolean useKMeansPlusPlus) {
        return setup(clusterCount, maxIterationCount, distanceFunction, false, useKMeansPlusPlus);
    }

    /**
     * Same as {@link #setup(int, double, Distance, boolean, boolean, boolean)} with a non-inverse distance function.
     */
    public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction,
                    boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
        ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false);
        clusteringStrategy.endWhenDistributionVariationRateLessThan(minDistributionVariationRate);
        return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
    }

    public static KMeansClustering setup(int clusterCount, Distance distanceFunction,
                    boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
        ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false);
        clusteringStrategy.endWhenDistributionVariationRateLessThan(VARIATION_TOLERANCE);
        return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
    }

}
|
|
|
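For context, a minimal usage sketch of the class deleted above; this is a sketch under assumed signatures, not the module's confirmed API. Point, ClusterSet and the inherited applyTo(...) live elsewhere in this clustering module and do not appear in this diff.

// Hedged sketch: clustering random 2-d points with the removed KMeansClustering.
import java.util.ArrayList;
import java.util.List;

import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.Point;
import org.nd4j.linalg.factory.Nd4j;

public class KMeansUsageSketch {
    public static void main(String[] args) {
        // 5 clusters, at most 100 iterations, plain (non k-means++) seeding.
        KMeansClustering kmeans = KMeansClustering.setup(5, 100, Distance.EUCLIDEAN, false);

        List<Point> points = new ArrayList<>();
        for (int i = 0; i < 500; i++) {
            points.add(new Point(Nd4j.rand(1, 2))); // Point(INDArray) constructor is assumed
        }

        ClusterSet clusters = kmeans.applyTo(points); // applyTo is inherited from BaseClusteringAlgorithm
        System.out.println("clusters: " + clusters.getClusterCount());
    }
}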
@@ -1,88 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.lsh;

import org.nd4j.linalg.api.ndarray.INDArray;

public interface LSH {

    /**
     * Returns the name of the distance measure associated with the LSH family of this implementation.
     * Beware, hashing families and their amplification constructs are distance-specific.
     */
    String getDistanceMeasure();

    /**
     * Returns the size of a hash compared against in one hashing bucket, corresponding to an AND construction:
     * denoting hashLength by h, it amplifies a (d1, d2, p1, p2)-sensitive hash family into a
     * (d1, d2, p1^h, p2^h)-sensitive one (the match probability decreases with h).
     *
     * @return the length of the hash in the AND construction used by this index
     */
    int getHashLength();

    /**
     * Denoting numTables by n, the OR construction amplifies a (d1, d2, p1, p2)-sensitive hash family into a
     * (d1, d2, 1-(1-p1)^n, 1-(1-p2)^n)-sensitive one (the match probability increases with n).
     *
     * @return the number of hash tables in the OR construction used by this index
     */
    int getNumTables();

    /**
     * @return the dimension of the index vectors and queries
     */
    int getInDimension();

    /**
     * Populates the index with data vectors.
     * @param data the vectors to index
     */
    void makeIndex(INDArray data);

    /**
     * Returns the set of all vectors that could approximately be considered neighbors of the query,
     * without selection on the basis of distance or number of neighbors.
     * @param query a vector to find neighbors for
     * @return its approximate neighbors, unfiltered
     */
    INDArray bucket(INDArray query);

    /**
     * Returns the approximate neighbors within a distance bound.
     * @param query a vector to find neighbors for
     * @param maxRange the maximum distance between results and the query
     * @return approximate neighbors within the distance bound
     */
    INDArray search(INDArray query, double maxRange);

    /**
     * Returns the approximate neighbors within a k-closest bound.
     * @param query a vector to find neighbors for
     * @param k the maximum number of closest neighbors to return
     * @return at most k neighbors of the query, ordered by increasing distance
     */
    INDArray search(INDArray query, int k);
}
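The AND and OR amplifications described in these javadocs compose. A small self-contained sketch of the standard formula (an illustration of the math, not code from this module):

// With base collision probability p, AND-length h and n OR-tables,
// the composed family collides with probability 1 - (1 - p^h)^n.
public class LshAmplificationSketch {
    static double amplified(double p, int h, int n) {
        double pAnd = Math.pow(p, h);           // AND: all h hash bits must agree
        return 1.0 - Math.pow(1.0 - pAnd, n);   // OR: at least one of n tables agrees
    }

    public static void main(String[] args) {
        System.out.println(amplified(0.9, 12, 3)); // close pairs stay fairly likely to collide
        System.out.println(amplified(0.5, 12, 3)); // far pairs become very unlikely to collide
    }
}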
@@ -1,227 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.lsh;

import lombok.Getter;
import lombok.val;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.same.Sign;
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.ops.transforms.Transforms;

import java.util.Arrays;

public class RandomProjectionLSH implements LSH {

    @Override
    public String getDistanceMeasure(){
        return "cosinedistance";
    }

    @Getter private int hashLength;
    @Getter private int numTables;
    @Getter private int inDimension;
    @Getter private double radius;

    INDArray randomProjection;
    INDArray index;
    INDArray indexData;

    private INDArray gaussianRandomMatrix(int[] shape, Random rng){
        INDArray res = Nd4j.create(shape);
        GaussianDistribution op1 = new GaussianDistribution(res, 0.0, 1.0 / Math.sqrt(shape[0]));
        Nd4j.getExecutioner().exec(op1, rng);
        return res;
    }

    public RandomProjectionLSH(int hashLength, int numTables, int inDimension, double radius){
        this(hashLength, numTables, inDimension, radius, Nd4j.getRandom());
    }

    /**
     * Creates a locality-sensitive hashing index for the cosine distance,
     * a (d1, d2, (180 − d1)/180, (180 − d2)/180)-sensitive hash family before amplification.
     *
     * @param hashLength the length of the compared hash in an AND construction
     * @param numTables the entropy-equivalent of a number of hash tables in an OR construction, implemented here with
     *                  the multiple probes of Panigrahy (op. cit.)
     * @param inDimension the dimensionality of the points being indexed
     * @param radius the radius around the query within which probes are generated; instead of using multiple physical
     *               hash tables, the OR construction is simulated with entropy probes drawn at this distance
     * @param rng a Random object to draw samples from
     */
    public RandomProjectionLSH(int hashLength, int numTables, int inDimension, double radius, Random rng){
        this.hashLength = hashLength;
        this.numTables = numTables;
        this.inDimension = inDimension;
        this.radius = radius;
        randomProjection = gaussianRandomMatrix(new int[]{inDimension, hashLength}, rng);
    }

    /**
     * This picks uniformly distributed random points on the surface of a sphere using the method of:
     *
     * An efficient method for generating uniformly distributed points on the surface of an n-dimensional sphere
     * JS Hicks, RF Wheeling - Communications of the ACM, 1959
     * @param data a query to generate multiple probes for
     * @return `numTables` probes around the query, one per simulated hash table
     */
    public INDArray entropy(INDArray data){

        INDArray data2 =
                        Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.create(numTables, inDimension), radius));

        INDArray norms = Nd4j.norm2(data2.dup(), -1);

        Preconditions.checkState(norms.rank() == 1 && norms.size(0) == numTables, "Expected norm2 to have shape [%s], is %ndShape", norms.size(0), norms);

        data2.diviColumnVector(norms);
        data2.addiRowVector(data);
        return data2;
    }

    /**
     * Returns hash values for a particular query
     * @param data a query vector
     * @return its hashed value
     */
    public INDArray hash(INDArray data) {
        if (data.shape()[1] != inDimension){
            throw new ND4JIllegalStateException(
                            String.format("Invalid shape: Requested INDArray shape %s, this table expects dimension %d",
                                            Arrays.toString(data.shape()), inDimension));
        }
        INDArray projected = data.mmul(randomProjection);
        INDArray res = Nd4j.getExecutioner().exec(new Sign(projected));
        return res;
    }

    /**
     * Populates the index. Beware: not incremental, any further call replaces the index instead of adding to it.
     * @param data the vectors to index
     */
    @Override
    public void makeIndex(INDArray data) {
        index = hash(data);
        indexData = data;
    }

    // data elements in the same bucket as the query, without entropy
    INDArray rawBucketOf(INDArray query){
        INDArray pattern = hash(query);

        INDArray res = Nd4j.zeros(DataType.BOOL, index.shape());
        Nd4j.getExecutioner().exec(new BroadcastEqualTo(index, pattern, res, -1));
        return res.castTo(Nd4j.defaultFloatingPointType()).min(-1);
    }

    @Override
    public INDArray bucket(INDArray query) {
        INDArray queryRes = rawBucketOf(query);

        if(numTables > 1) {
            INDArray entropyQueries = entropy(query);

            // loop, addi + conditionalreplace -> poor man's OR function
            for (int i = 0; i < numTables; i++) {
                INDArray row = entropyQueries.getRow(i, true);
                queryRes.addi(rawBucketOf(row));
            }
            BooleanIndexing.replaceWhere(queryRes, 1.0, Conditions.greaterThan(0.0));
        }

        return queryRes;
    }

    // data elements in the same entropy bucket as the query
    INDArray bucketData(INDArray query){
        INDArray mask = bucket(query);
        int nRes = mask.sum(0).getInt(0);
        INDArray res = Nd4j.create(new int[] {nRes, inDimension});
        int j = 0;
        for (int i = 0; i < nRes; i++){
            while (mask.getInt(j) == 0 && j < mask.length() - 1) {
                j += 1;
            }
            if (mask.getInt(j) == 1) res.putRow(i, indexData.getRow(j));
            j += 1;
        }
        return res;
    }

    @Override
    public INDArray search(INDArray query, double maxRange) {
        if (maxRange < 0)
            throw new IllegalArgumentException("ANN search should have a positive maximum search radius");

        INDArray bucketData = bucketData(query);
        INDArray distances = Transforms.allCosineDistances(bucketData, query, -1);
        INDArray[] idxs = Nd4j.sortWithIndices(distances, -1, true);

        INDArray shuffleIndexes = idxs[0];
        INDArray sortedDistances = idxs[1];
        int accepted = 0;
        while (accepted < sortedDistances.length() && sortedDistances.getDouble(accepted) <= maxRange) accepted += 1;

        INDArray res = Nd4j.create(new int[] {accepted, inDimension});
        for(int i = 0; i < accepted; i++){
            res.putRow(i, bucketData.getRow(shuffleIndexes.getInt(i)));
        }
        return res;
    }

    @Override
    public INDArray search(INDArray query, int k) {
        if (k < 1)
            throw new IllegalArgumentException("An ANN search for k neighbors should at least seek one neighbor");

        INDArray bucketData = bucketData(query);
        INDArray distances = Transforms.allCosineDistances(bucketData, query, -1);
        INDArray[] idxs = Nd4j.sortWithIndices(distances, -1, true);

        INDArray shuffleIndexes = idxs[0];
        INDArray sortedDistances = idxs[1];
        val accepted = Math.min(k, sortedDistances.shape()[1]);

        INDArray res = Nd4j.create(accepted, inDimension);
        for(int i = 0; i < accepted; i++){
            res.putRow(i, bucketData.getRow(shuffleIndexes.getInt(i)));
        }
        return res;
    }
}
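A hedged sketch of indexing and querying with the class above; the parameter values are illustrative assumptions, not recommended settings.

// Hedged sketch: index 1000 random vectors, then run both search variants.
import org.deeplearning4j.clustering.lsh.RandomProjectionLSH;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RandomProjectionLSHSketch {
    public static void main(String[] args) {
        int dim = 100;
        // hashLength 12 (AND), 3 entropy-probe "tables" (OR), probe radius 0.1
        RandomProjectionLSH lsh = new RandomProjectionLSH(12, 3, dim, 0.1);

        INDArray data = Nd4j.rand(1000, dim);
        lsh.makeIndex(data);                          // not incremental: re-indexing replaces the old index

        INDArray query = Nd4j.rand(1, dim);
        INDArray tenNearest = lsh.search(query, 10);  // k-nearest variant
        INDArray within = lsh.search(query, 0.2);     // distance-bound variant
        System.out.println(tenNearest.rows() + " / " + within.rows() + " rows returned");
    }
}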
@@ -1,38 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.optimisation;

import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.io.Serializable;

@Data
@NoArgsConstructor(access = AccessLevel.PROTECTED)
@AllArgsConstructor
public class ClusteringOptimization implements Serializable {

    private ClusteringOptimizationType type;
    private double value;

}
@@ -1,28 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.optimisation;

public enum ClusteringOptimizationType {
    MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE,
    MINIMIZE_MAXIMUM_POINT_TO_CENTER_DISTANCE,
    MINIMIZE_AVERAGE_POINT_TO_POINT_DISTANCE,
    MINIMIZE_MAXIMUM_POINT_TO_POINT_DISTANCE,
    MINIMIZE_PER_CLUSTER_POINT_COUNT
}
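Together, these two small types form an optimisation target for a clustering run. A construction sketch; the code that consumes the target sits elsewhere in the removed module, so this only illustrates the generated constructor:

// Hedged sketch: build a target that minimizes mean point-to-center distance.
import org.deeplearning4j.clustering.optimisation.ClusteringOptimization;
import org.deeplearning4j.clustering.optimisation.ClusteringOptimizationType;

public class OptimizationTargetSketch {
    public static void main(String[] args) {
        ClusteringOptimization target = new ClusteringOptimization(
                        ClusteringOptimizationType.MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE, 0.05);
        System.out.println(target); // @Data supplies toString/equals/hashCode
    }
}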
@@ -1,115 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.quadtree;

import org.nd4j.linalg.api.ndarray.INDArray;

import java.io.Serializable;

public class Cell implements Serializable {
    // center (x, y) with half-width hw and half-height hh
    private double x, y, hw, hh;

    public Cell(double x, double y, double hw, double hh) {
        this.x = x;
        this.y = y;
        this.hw = hw;
        this.hh = hh;
    }

    /**
     * Whether the given point is contained
     * within this cell
     * @param point the point to check
     * @return true if the point is contained, false otherwise
     */
    public boolean containsPoint(INDArray point) {
        double first = point.getDouble(0), second = point.getDouble(1);
        return x - hw <= first && x + hw >= first && y - hh <= second && y + hh >= second;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o)
            return true;
        if (!(o instanceof Cell))
            return false;

        Cell cell = (Cell) o;

        if (Double.compare(cell.hh, hh) != 0)
            return false;
        if (Double.compare(cell.hw, hw) != 0)
            return false;
        if (Double.compare(cell.x, x) != 0)
            return false;
        return Double.compare(cell.y, y) == 0;
    }

    @Override
    public int hashCode() {
        int result;
        long temp;
        temp = Double.doubleToLongBits(x);
        result = (int) (temp ^ (temp >>> 32));
        temp = Double.doubleToLongBits(y);
        result = 31 * result + (int) (temp ^ (temp >>> 32));
        temp = Double.doubleToLongBits(hw);
        result = 31 * result + (int) (temp ^ (temp >>> 32));
        temp = Double.doubleToLongBits(hh);
        result = 31 * result + (int) (temp ^ (temp >>> 32));
        return result;
    }

    public double getX() {
        return x;
    }

    public void setX(double x) {
        this.x = x;
    }

    public double getY() {
        return y;
    }

    public void setY(double y) {
        this.y = y;
    }

    public double getHw() {
        return hw;
    }

    public void setHw(double hw) {
        this.hw = hw;
    }

    public double getHh() {
        return hh;
    }

    public void setHh(double hh) {
        this.hh = hh;
    }

}
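Cell is an axis-aligned box given by a center and half-extents, so containment is four comparisons. A tiny sketch of the check:

// Hedged sketch: a cell centered at the origin covering [-1, 1] x [-1, 1].
import org.deeplearning4j.clustering.quadtree.Cell;
import org.nd4j.linalg.factory.Nd4j;

public class CellSketch {
    public static void main(String[] args) {
        Cell cell = new Cell(0.0, 0.0, 1.0, 1.0);
        System.out.println(cell.containsPoint(Nd4j.create(new double[] {0.5, -0.25}))); // true
        System.out.println(cell.containsPoint(Nd4j.create(new double[] {1.5, 0.0})));   // false
    }
}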
@@ -1,383 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.quadtree;

import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.io.Serializable;

import static java.lang.Math.max;

public class QuadTree implements Serializable {
    private QuadTree parent, northWest, northEast, southWest, southEast;
    private boolean isLeaf = true;
    private int size, cumSize;
    private Cell boundary;
    static final int QT_NO_DIMS = 2;
    static final int QT_NODE_CAPACITY = 1;
    private INDArray buf = Nd4j.create(QT_NO_DIMS); // scratch buffer for force computations
    private INDArray data, centerOfMass = Nd4j.create(QT_NO_DIMS);
    private int[] index = new int[QT_NODE_CAPACITY];


    /**
     * Pass in a matrix
     * @param data the data to index, one 2-d point per row
     */
    public QuadTree(INDArray data) {
        INDArray meanY = data.mean(0);
        INDArray minY = data.min(0);
        INDArray maxY = data.max(0);
        init(data, meanY.getDouble(0), meanY.getDouble(1),
                        max(maxY.getDouble(0) - meanY.getDouble(0), meanY.getDouble(0) - minY.getDouble(0))
                                        + Nd4j.EPS_THRESHOLD,
                        max(maxY.getDouble(1) - meanY.getDouble(1), meanY.getDouble(1) - minY.getDouble(1))
                                        + Nd4j.EPS_THRESHOLD);
        fill();
    }

    public QuadTree(QuadTree parent, INDArray data, Cell boundary) {
        this.parent = parent;
        this.boundary = boundary;
        this.data = data;
    }

    public QuadTree(Cell boundary) {
        this.boundary = boundary;
    }

    private void init(INDArray data, double x, double y, double hw, double hh) {
        boundary = new Cell(x, y, hw, hh);
        this.data = data;
    }

    private void fill() {
        for (int i = 0; i < data.rows(); i++)
            insert(i);
    }


    /**
     * Returns the child quadrant that should hold the given coordinates
     *
     * @param coordinates the 2-d point to locate
     * @return the child quadrant containing the point
     */
    protected QuadTree findIndex(INDArray coordinates) {

        // Compute the sector for the coordinates
        boolean left = (coordinates.getDouble(0) <= (boundary.getX() + boundary.getHw() / 2));
        boolean top = (coordinates.getDouble(1) <= (boundary.getY() + boundary.getHh() / 2));

        // top left
        QuadTree index = getNorthWest();
        if (left) {
            // left side
            if (!top) {
                // bottom left
                index = getSouthWest();
            }
        } else {
            // right side
            if (top) {
                // top right
                index = getNorthEast();
            } else {
                // bottom right
                index = getSouthEast();
            }
        }

        return index;
    }


    /**
     * Insert an index of the data in to the tree
     * @param newIndex the index to insert in to the tree
     * @return whether the index was inserted or not
     */
    public boolean insert(int newIndex) {
        // Ignore objects which do not belong in this quad tree
        INDArray point = data.slice(newIndex);
        if (!boundary.containsPoint(point))
            return false;

        cumSize++;
        double mult1 = (double) (cumSize - 1) / (double) cumSize;
        double mult2 = 1.0 / (double) cumSize;

        centerOfMass.muli(mult1);
        centerOfMass.addi(point.mul(mult2));

        // If there is space in this quad tree and it is a leaf, add the object here
        if (isLeaf() && size < QT_NODE_CAPACITY) {
            index[size] = newIndex;
            size++;
            return true;
        }

        // duplicate point
        if (size > 0) {
            for (int i = 0; i < size; i++) {
                INDArray compPoint = data.slice(index[i]);
                if (point.getDouble(0) == compPoint.getDouble(0) && point.getDouble(1) == compPoint.getDouble(1))
                    return true;
            }
        }

        // If this node has already been subdivided, just add the element to the
        // appropriate cell
        if (!isLeaf()) {
            QuadTree index = findIndex(point);
            index.insert(newIndex);
            return true;
        }

        if (isLeaf())
            subDivide();

        boolean ret = insertIntoOneOf(newIndex);

        return ret;
    }

    private boolean insertIntoOneOf(int index) {
        boolean success = false;
        success = northWest.insert(index);
        if (!success)
            success = northEast.insert(index);
        if (!success)
            success = southWest.insert(index);
        if (!success)
            success = southEast.insert(index);
        return success;
    }


    /**
     * Returns whether the tree is consistent or not
     * @return whether the tree is consistent or not
     */
    public boolean isCorrect() {

        for (int n = 0; n < size; n++) {
            INDArray point = data.slice(index[n]);
            if (!boundary.containsPoint(point))
                return false;
        }

        return isLeaf() || northWest.isCorrect() && northEast.isCorrect() && southWest.isCorrect()
                        && southEast.isCorrect();
    }


    /**
     * Create four children
     * which fully divide this cell
     * into four quads of equal area
     */
    public void subDivide() {
        northWest = new QuadTree(this, data, new Cell(boundary.getX() - .5 * boundary.getHw(),
                        boundary.getY() - .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh()));
        northEast = new QuadTree(this, data, new Cell(boundary.getX() + .5 * boundary.getHw(),
                        boundary.getY() - .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh()));
        southWest = new QuadTree(this, data, new Cell(boundary.getX() - .5 * boundary.getHw(),
                        boundary.getY() + .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh()));
        southEast = new QuadTree(this, data, new Cell(boundary.getX() + .5 * boundary.getHw(),
                        boundary.getY() + .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh()));
    }


    /**
     * Compute non-edge forces using Barnes-Hut
     * @param pointIndex the index of the point to compute forces for
     * @param theta the accuracy/speed trade-off (0 means exact)
     * @param negativeForce the accumulator for the negative force
     * @param sumQ the accumulator for the normalization constant
     */
    public void computeNonEdgeForces(int pointIndex, double theta, INDArray negativeForce, AtomicDouble sumQ) {
        // Make sure that we spend no time on empty nodes or self-interactions
        if (cumSize == 0 || (isLeaf() && size == 1 && index[0] == pointIndex))
            return;

        // Compute distance between point and center-of-mass
        buf.assign(data.slice(pointIndex)).subi(centerOfMass);

        double D = Nd4j.getBlasWrapper().dot(buf, buf);

        // Check whether we can use this node as a "summary"
        if (isLeaf || FastMath.max(boundary.getHh(), boundary.getHw()) / FastMath.sqrt(D) < theta) {

            // Compute and add t-SNE force between point and current node
            double Q = 1.0 / (1.0 + D);
            double mult = cumSize * Q;
            sumQ.addAndGet(mult);
            mult *= Q;
            negativeForce.addi(buf.mul(mult));

        } else {

            // Recursively apply Barnes-Hut to children
            northWest.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
            northEast.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
            southWest.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
            southEast.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
        }
    }


    /**
     * Compute edge forces for the given sparse graph
     * @param rowP row offsets into colP/valP (CSR layout), one entry per point plus one
     * @param colP column indices of the edges
     * @param valP values of the edges
     * @param N the number of points
     * @param posF the accumulator for positive forces, one row per point
     */
    public void computeEdgeForces(INDArray rowP, INDArray colP, INDArray valP, int N, INDArray posF) {
        if (!rowP.isVector())
            throw new IllegalArgumentException("RowP must be a vector");

        // Loop over all edges in the graph
        double D;
        for (int n = 0; n < N; n++) {
            for (int i = rowP.getInt(n); i < rowP.getInt(n + 1); i++) {

                // Compute pairwise distance and Q-value
                buf.assign(data.slice(n)).subi(data.slice(colP.getInt(i)));

                D = Nd4j.getBlasWrapper().dot(buf, buf);
                D = valP.getDouble(i) / D;

                // Sum positive force
                posF.slice(n).addi(buf.mul(D));
            }
        }
    }


    /**
     * The depth of the node
     * @return the depth of the node
     */
    public int depth() {
        if (isLeaf())
            return 1;
        return 1 + max(max(northWest.depth(), northEast.depth()), max(southWest.depth(), southEast.depth()));
    }

    public INDArray getCenterOfMass() {
        return centerOfMass;
    }

    public void setCenterOfMass(INDArray centerOfMass) {
        this.centerOfMass = centerOfMass;
    }

    public QuadTree getParent() {
        return parent;
    }

    public void setParent(QuadTree parent) {
        this.parent = parent;
    }

    public QuadTree getNorthWest() {
        return northWest;
    }

    public void setNorthWest(QuadTree northWest) {
        this.northWest = northWest;
    }

    public QuadTree getNorthEast() {
        return northEast;
    }

    public void setNorthEast(QuadTree northEast) {
        this.northEast = northEast;
    }

    public QuadTree getSouthWest() {
        return southWest;
    }

    public void setSouthWest(QuadTree southWest) {
        this.southWest = southWest;
    }

    public QuadTree getSouthEast() {
        return southEast;
    }

    public void setSouthEast(QuadTree southEast) {
        this.southEast = southEast;
    }

    public boolean isLeaf() {
        return isLeaf;
    }

    public void setLeaf(boolean isLeaf) {
        this.isLeaf = isLeaf;
    }

    public int getSize() {
        return size;
    }

    public void setSize(int size) {
        this.size = size;
    }

    public int getCumSize() {
        return cumSize;
    }

    public void setCumSize(int cumSize) {
        this.cumSize = cumSize;
    }

    public Cell getBoundary() {
        return boundary;
    }

    public void setBoundary(Cell boundary) {
        this.boundary = boundary;
    }
}
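A hedged sketch of driving the quadtree above as a Barnes-Hut summary structure over 2-d points, the way t-SNE-style embeddings use it; array shapes are assumptions and may need adjusting to the nd4j version in use.

// Hedged sketch: build the tree over random points, compute forces for point 0.
import org.deeplearning4j.clustering.quadtree.QuadTree;
import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class QuadTreeSketch {
    public static void main(String[] args) {
        INDArray points = Nd4j.rand(256, 2);   // strictly 2-d (QT_NO_DIMS == 2)
        QuadTree tree = new QuadTree(points);  // bounds derived from the data, then filled

        INDArray negativeForce = Nd4j.zeros(2);
        AtomicDouble sumQ = new AtomicDouble(0.0);
        // theta trades accuracy for speed: 0 is exact, larger values use coarser summaries
        tree.computeNonEdgeForces(0, 0.5, negativeForce, sumQ);
        System.out.println("depth=" + tree.depth() + ", sumQ=" + sumQ.get());
    }
}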
@@ -1,104 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.randomprojection;

import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;

import java.util.ArrayList;
import java.util.List;

@Data
public class RPForest {

    private int numTrees;
    private List<RPTree> trees;
    private INDArray data;
    private int maxSize = 1000;
    private String similarityFunction;

    /**
     * Create the rp forest with the specified number of trees
     * @param numTrees the number of trees in the forest
     * @param maxSize the max size of each tree
     * @param similarityFunction the distance function to use
     */
    public RPForest(int numTrees,int maxSize,String similarityFunction) {
        this.numTrees = numTrees;
        this.maxSize = maxSize;
        this.similarityFunction = similarityFunction;
        trees = new ArrayList<>(numTrees);
    }


    /**
     * Build the trees from the given dataset
     * @param x the input dataset (should be a 2d matrix)
     */
    public void fit(INDArray x) {
        this.data = x;
        for(int i = 0; i < numTrees; i++) {
            RPTree tree = new RPTree(data.columns(),maxSize,similarityFunction);
            tree.buildTree(x);
            trees.add(tree);
        }
    }

    /**
     * Get all candidates relative to a specific datapoint.
     * @param input the query vector
     * @return the indices of all candidate neighbors across the forest's trees
     */
    public INDArray getAllCandidates(INDArray input) {
        return RPUtils.getAllCandidates(input,trees,similarityFunction);
    }

    /**
     * Query results up to length n
     * nearest neighbors
     * @param toQuery the query item
     * @param n the number of nearest neighbors for the given data point
     * @return the indices for the nearest neighbors
     */
    public INDArray queryAll(INDArray toQuery,int n) {
        return RPUtils.queryAll(toQuery,data,trees,n,similarityFunction);
    }


    /**
     * Query all with the distances
     * sorted by index
     * @param query the query vector
     * @param numResults the number of results to return
     * @return a list of samples
     */
    public List<Pair<Double, Integer>> queryWithDistances(INDArray query, int numResults) {
        return RPUtils.queryAllWithDistances(query,this.data, trees,numResults,similarityFunction);
    }

}
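A hedged usage sketch for the forest above; the tree count, leaf size and similarity-function name are illustrative assumptions.

// Hedged sketch: approximate nearest neighbors with the removed RPForest.
import org.deeplearning4j.clustering.randomprojection.RPForest;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RPForestSketch {
    public static void main(String[] args) {
        RPForest forest = new RPForest(10, 100, "euclidean"); // 10 trees, leaves of at most 100 points
        INDArray data = Nd4j.rand(1000, 32);
        forest.fit(data);

        INDArray query = data.getRow(0, true);
        INDArray neighbors = forest.queryAll(query, 5);       // indices of ~5 nearest rows
        System.out.println(neighbors);
    }
}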
@@ -1,57 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.randomprojection;

import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

@Data
public class RPHyperPlanes {
    private int dim;
    private INDArray wholeHyperPlane;

    public RPHyperPlanes(int dim) {
        this.dim = dim;
    }

    public INDArray getHyperPlaneAt(int depth) {
        if(wholeHyperPlane.isVector())
            return wholeHyperPlane;
        return wholeHyperPlane.slice(depth);
    }


    /**
     * Append a new random hyper plane (a normalized random row vector).
     */
    public void addRandomHyperPlane() {
        INDArray newPlane = Nd4j.randn(new int[] {1,dim});
        newPlane.divi(newPlane.normmaxNumber());
        if(wholeHyperPlane == null)
            wholeHyperPlane = newPlane;
        else {
            wholeHyperPlane = Nd4j.concat(0,wholeHyperPlane,newPlane);
        }
    }

}
@@ -1,48 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.randomprojection;

import lombok.Data;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Future;

@Data
public class RPNode {
    private int depth;
    private RPNode left,right;
    private Future<RPNode> leftFuture,rightFuture;
    private List<Integer> indices;
    private double median;
    private RPTree tree;


    public RPNode(RPTree tree,int depth) {
        this.depth = depth;
        this.tree = tree;
        indices = new ArrayList<>();
    }

}
@@ -1,130 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.randomprojection;

import lombok.Builder;
import lombok.Data;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;

@Data
public class RPTree {
    private RPNode root;
    private RPHyperPlanes rpHyperPlanes;
    private int dim;
    // also known as leaf size
    private int maxSize;
    private INDArray X;
    private String similarityFunction = "euclidean";
    private WorkspaceConfiguration workspaceConfiguration;
    private ExecutorService searchExecutor;
    private int searchWorkers;

    /**
     * @param dim the dimension of the vectors
     * @param maxSize the max size of the leaves
     * @param similarityFunction the distance function to use
     */
    @Builder
    public RPTree(int dim, int maxSize,String similarityFunction) {
        this.dim = dim;
        this.maxSize = maxSize;
        rpHyperPlanes = new RPHyperPlanes(dim);
        root = new RPNode(this,0);
        this.similarityFunction = similarityFunction;
        workspaceConfiguration = WorkspaceConfiguration.builder().cyclesBeforeInitialization(1)
                        .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.FIRST_LOOP)
                        .policyMirroring(MirroringPolicy.FULL).policyReset(ResetPolicy.BLOCK_LEFT)
                        .policySpill(SpillPolicy.REALLOCATE).build();
    }

    /**
     * @param dim the dimension of the vectors
     * @param maxSize the max size of the leaves (the similarity function defaults to "euclidean")
     */
    public RPTree(int dim, int maxSize) {
        this(dim,maxSize,"euclidean");
    }

    /**
     * Build the tree with the given input data
     * @param x the data matrix, one vector per row
     */
    public void buildTree(INDArray x) {
        this.X = x;
        for(int i = 0; i < x.rows(); i++) {
            root.getIndices().add(i);
        }

        RPUtils.buildTree(this,root,rpHyperPlanes,
                        x,maxSize,0,similarityFunction);
    }


    public void addNodeAtIndex(int idx,INDArray toAdd) {
        RPNode query = RPUtils.query(root,rpHyperPlanes,toAdd,similarityFunction);
        query.getIndices().add(idx);
    }


    public List<RPNode> getLeaves() {
        List<RPNode> nodes = new ArrayList<>();
        RPUtils.scanForLeaves(nodes,getRoot());
        return nodes;
    }


    /**
     * Query all with the distances
     * sorted by index
     * @param query the query vector
     * @param numResults the number of results to return
     * @return a list of samples
     */
    public List<Pair<Double, Integer>> queryWithDistances(INDArray query, int numResults) {
        return RPUtils.queryAllWithDistances(query,X,Arrays.asList(this),numResults,similarityFunction);
    }

    public INDArray query(INDArray query,int numResults) {
        return RPUtils.queryAll(query,X,Arrays.asList(this),numResults,similarityFunction);
    }

    public List<Integer> getCandidates(INDArray target) {
        return RPUtils.getCandidates(target,Arrays.asList(this),similarityFunction);
    }

}
@ -1,481 +0,0 @@
|
||||||
/*
|
|
||||||
* ******************************************************************************
|
|
||||||
* *
|
|
||||||
* *
|
|
||||||
* * This program and the accompanying materials are made available under the
|
|
||||||
* * terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
* *
|
|
||||||
* * See the NOTICE file distributed with this work for additional
|
|
||||||
* * information regarding copyright ownership.
|
|
||||||
* * Unless required by applicable law or agreed to in writing, software
|
|
||||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* * License for the specific language governing permissions and limitations
|
|
||||||
* * under the License.
|
|
||||||
* *
|
|
||||||
* * SPDX-License-Identifier: Apache-2.0
|
|
||||||
* *****************************************************************************
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.deeplearning4j.clustering.randomprojection;
|
|
||||||
|
|
||||||
import org.nd4j.shade.guava.primitives.Doubles;
|
|
||||||
import lombok.val;
|
|
||||||
import org.nd4j.autodiff.functions.DifferentialFunction;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.api.ops.ReduceOp;
|
|
||||||
import org.nd4j.linalg.api.ops.impl.reduce3.*;
|
|
||||||
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
import org.nd4j.common.primitives.Pair;
|
|
||||||
|
|
||||||
import java.util.*;
|
|
||||||
|
|
||||||
public class RPUtils {
|
|
||||||
|
|
||||||
|
|
||||||
private static ThreadLocal<Map<String,DifferentialFunction>> functionInstances = new ThreadLocal<>();
|
|
||||||
|
|
||||||
public static <T extends DifferentialFunction> DifferentialFunction getOp(String name,
|
|
||||||
INDArray x,
|
|
||||||
INDArray y,
|
|
||||||
INDArray result) {
|
|
||||||
Map<String,DifferentialFunction> ops = functionInstances.get();
|
|
||||||
if(ops == null) {
|
|
||||||
ops = new HashMap<>();
|
|
||||||
functionInstances.set(ops);
|
|
||||||
}
|
|
||||||
|
|
||||||
boolean allDistances = x.length() != y.length();
|
|
||||||
|
|
||||||
switch(name) {
|
|
||||||
case "cosinedistance":
|
|
||||||
if(!ops.containsKey(name) || ((CosineDistance)ops.get(name)).isComplexAccumulation() != allDistances) {
|
|
||||||
CosineDistance cosineDistance = new CosineDistance(x,y,result,allDistances);
|
|
||||||
ops.put(name,cosineDistance);
|
|
||||||
return cosineDistance;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
CosineDistance cosineDistance = (CosineDistance) ops.get(name);
|
|
||||||
return cosineDistance;
|
|
||||||
}
|
|
||||||
case "cosinesimilarity":
|
|
||||||
if(!ops.containsKey(name) || ((CosineSimilarity)ops.get(name)).isComplexAccumulation() != allDistances) {
|
|
||||||
CosineSimilarity cosineSimilarity = new CosineSimilarity(x,y,result,allDistances);
|
|
||||||
ops.put(name,cosineSimilarity);
|
|
||||||
return cosineSimilarity;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
CosineSimilarity cosineSimilarity = (CosineSimilarity) ops.get(name);
|
|
||||||
cosineSimilarity.setX(x);
|
|
||||||
cosineSimilarity.setY(y);
|
|
||||||
cosineSimilarity.setZ(result);
|
|
||||||
return cosineSimilarity;
|
|
||||||
|
|
||||||
}
|
|
||||||
case "manhattan":
|
|
||||||
if(!ops.containsKey(name) || ((ManhattanDistance)ops.get(name)).isComplexAccumulation() != allDistances) {
|
|
||||||
ManhattanDistance manhattanDistance = new ManhattanDistance(x,y,result,allDistances);
|
|
||||||
ops.put(name,manhattanDistance);
|
|
||||||
return manhattanDistance;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
ManhattanDistance manhattanDistance = (ManhattanDistance) ops.get(name);
|
|
||||||
manhattanDistance.setX(x);
|
|
||||||
manhattanDistance.setY(y);
|
|
||||||
manhattanDistance.setZ(result);
|
|
||||||
return manhattanDistance;
|
|
||||||
}
|
|
||||||
case "jaccard":
|
|
||||||
if(!ops.containsKey(name) || ((JaccardDistance)ops.get(name)).isComplexAccumulation() != allDistances) {
|
|
||||||
JaccardDistance jaccardDistance = new JaccardDistance(x,y,result,allDistances);
|
|
||||||
ops.put(name,jaccardDistance);
|
|
||||||
return jaccardDistance;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
JaccardDistance jaccardDistance = (JaccardDistance) ops.get(name);
|
|
||||||
jaccardDistance.setX(x);
|
|
||||||
jaccardDistance.setY(y);
|
|
||||||
jaccardDistance.setZ(result);
|
|
||||||
return jaccardDistance;
|
|
||||||
}
|
|
||||||
case "hamming":
|
|
||||||
if(!ops.containsKey(name) || ((HammingDistance)ops.get(name)).isComplexAccumulation() != allDistances) {
|
|
||||||
HammingDistance hammingDistance = new HammingDistance(x,y,result,allDistances);
|
|
||||||
ops.put(name,hammingDistance);
|
|
||||||
return hammingDistance;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
HammingDistance hammingDistance = (HammingDistance) ops.get(name);
|
|
||||||
hammingDistance.setX(x);
|
|
||||||
hammingDistance.setY(y);
|
|
||||||
hammingDistance.setZ(result);
|
|
||||||
return hammingDistance;
|
|
||||||
}
|
|
||||||
//euclidean
|
|
||||||
default:
|
|
||||||
if(!ops.containsKey(name) || ((EuclideanDistance)ops.get(name)).isComplexAccumulation() != allDistances) {
|
|
||||||
EuclideanDistance euclideanDistance = new EuclideanDistance(x,y,result,allDistances);
|
|
||||||
ops.put(name,euclideanDistance);
|
|
||||||
return euclideanDistance;
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
EuclideanDistance euclideanDistance = (EuclideanDistance) ops.get(name);
|
|
||||||
euclideanDistance.setX(x);
|
|
||||||
euclideanDistance.setY(y);
|
|
||||||
euclideanDistance.setZ(result);
|
|
||||||
return euclideanDistance;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Query all trees using the given input and data
|
|
||||||
* @param toQuery the query vector
|
|
||||||
* @param X the input data to query
|
|
||||||
* @param trees the trees to query
|
|
||||||
* @param n the number of results to search for
|
|
||||||
* @param similarityFunction the similarity function to use
|
|
||||||
* @return the indices (in order) in the ndarray
|
|
||||||
*/
|
|
||||||
public static List<Pair<Double,Integer>> queryAllWithDistances(INDArray toQuery,INDArray X,List<RPTree> trees,int n,String similarityFunction) {
|
|
||||||
if(trees.isEmpty()) {
|
|
||||||
throw new ND4JIllegalArgumentException("Trees is empty!");
|
|
||||||
}
|
|
||||||
|
|
||||||
List<Integer> candidates = getCandidates(toQuery, trees,similarityFunction);
|
|
||||||
val sortedCandidates = sortCandidates(toQuery,X,candidates,similarityFunction);
|
|
||||||
int numReturns = Math.min(n,sortedCandidates.size());
|
|
||||||
List<Pair<Double,Integer>> ret = new ArrayList<>(numReturns);
|
|
||||||
for(int i = 0; i < numReturns; i++) {
|
|
||||||
ret.add(sortedCandidates.get(i));
|
|
||||||
}
|
|
||||||
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Query all trees using the given input and data
|
|
||||||
* @param toQuery the query vector
|
|
||||||
* @param X the input data to query
|
|
||||||
* @param trees the trees to query
|
|
||||||
* @param n the number of results to search for
|
|
||||||
* @param similarityFunction the similarity function to use
|
|
||||||
* @return the indices (in order) in the ndarray
|
|
||||||
*/
|
|
||||||
public static INDArray queryAll(INDArray toQuery,INDArray X,List<RPTree> trees,int n,String similarityFunction) {
|
|
||||||
if(trees.isEmpty()) {
|
|
||||||
throw new ND4JIllegalArgumentException("Trees is empty!");
|
|
||||||
}
|
|
||||||
|
|
||||||
List<Integer> candidates = getCandidates(toQuery, trees,similarityFunction);
|
|
||||||
val sortedCandidates = sortCandidates(toQuery,X,candidates,similarityFunction);
|
|
||||||
int numReturns = Math.min(n,sortedCandidates.size());
|
|
||||||
|
|
||||||
INDArray result = Nd4j.create(numReturns);
|
|
||||||
for(int i = 0; i < numReturns; i++) {
|
|
||||||
result.putScalar(i,sortedCandidates.get(i).getSecond());
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
    /**
     * Get the sorted distances given the
     * query vector, input data, given the list of possible search candidates
     * @param x the query vector
     * @param X the input data to use
     * @param candidates the possible search candidates
     * @param similarityFunction the similarity function to use
     * @return the (distance, index) pairs sorted by ascending distance
     */
    public static List<Pair<Double,Integer>> sortCandidates(INDArray x, INDArray X,
                                                            List<Integer> candidates,
                                                            String similarityFunction) {
        int prevIdx = -1;
        List<Pair<Double,Integer>> ret = new ArrayList<>();
        for(int i = 0; i < candidates.size(); i++) {
            if(candidates.get(i) != prevIdx) {
                ret.add(Pair.of(computeDistance(similarityFunction, X.slice(candidates.get(i)), x), candidates.get(i)));
            }

            // track the candidate value, not the loop index, so consecutive duplicates are skipped
            prevIdx = candidates.get(i);
        }

        Collections.sort(ret, new Comparator<Pair<Double, Integer>>() {
            @Override
            public int compare(Pair<Double, Integer> doubleIntegerPair, Pair<Double, Integer> t1) {
                return Doubles.compare(doubleIntegerPair.getFirst(), t1.getFirst());
            }
        });

        return ret;
    }

    /**
     * Get the search candidates as indices given the input
     * and similarity function
     * @param x the input data to search with
     * @param trees the trees to search
     * @param similarityFunction the function to use for similarity
     * @return the distinct candidate indices as an ndarray
     */
    public static INDArray getAllCandidates(INDArray x, List<RPTree> trees, String similarityFunction) {
        List<Integer> candidates = getCandidates(x, trees, similarityFunction);
        Collections.sort(candidates);

        int prevIdx = -1;
        int idxCount = 0;
        List<Pair<Integer,Integer>> scores = new ArrayList<>();
        for(int i = 0; i < candidates.size(); i++) {
            if(candidates.get(i) == prevIdx) {
                idxCount++;
            }
            else {
                // new index encountered: flush the count accumulated for the previous one
                if(prevIdx != -1) {
                    scores.add(Pair.of(idxCount, prevIdx));
                }
                idxCount = 1;
            }

            // track the candidate value, not the loop index
            prevIdx = candidates.get(i);
        }

        scores.add(Pair.of(idxCount, prevIdx));

        INDArray arr = Nd4j.create(scores.size());
        for(int i = 0; i < scores.size(); i++) {
            arr.putScalar(i, scores.get(i).getSecond());
        }

        return arr;
    }

    /**
     * Get the search candidates as indices given the input
     * and similarity function
     * @param x the input data to search with
     * @param roots the trees to search
     * @param similarityFunction the function to use for similarity
     * @return the list of indices as the search results
     */
    public static List<Integer> getCandidates(INDArray x, List<RPTree> roots, String similarityFunction) {
        Set<Integer> ret = new LinkedHashSet<>();
        for(RPTree tree : roots) {
            RPNode root = tree.getRoot();
            RPNode query = query(root, tree.getRpHyperPlanes(), x, similarityFunction);
            ret.addAll(query.getIndices());
        }

        return new ArrayList<>(ret);
    }

    /**
     * Query the tree starting from the given node
     * using the given hyper plane and similarity function
     * @param from the node to start from
     * @param planes the hyper plane to query
     * @param x the input data
     * @param similarityFunction the similarity function to use
     * @return the leaf node representing the given query from a
     * search in the tree
     */
    public static RPNode query(RPNode from, RPHyperPlanes planes, INDArray x, String similarityFunction) {
        if(from.getLeft() == null && from.getRight() == null) {
            return from;
        }

        INDArray hyperPlane = planes.getHyperPlaneAt(from.getDepth());
        double dist = computeDistance(similarityFunction, x, hyperPlane);
        if(dist <= from.getMedian()) {
            return query(from.getLeft(), planes, x, similarityFunction);
        }
        else {
            return query(from.getRight(), planes, x, similarityFunction);
        }
    }

    /**
     * Compute the distance between 2 vectors
     * given a function name. Valid function names:
     * euclidean: euclidean distance
     * cosinedistance: cosine distance
     * cosinesimilarity: cosine similarity
     * manhattan: manhattan distance
     * jaccard: jaccard distance
     * hamming: hamming distance
     * @param function the function to use (default euclidean distance)
     * @param x the first vector
     * @param y the second vector
     * @param result the array to store the per-row results in
     * @return the distances between corresponding rows of x and y
     */
    public static INDArray computeDistanceMulti(String function, INDArray x, INDArray y, INDArray result) {
        ReduceOp op = (ReduceOp) getOp(function, x, y, result);
        op.setDimensions(1);
        Nd4j.getExecutioner().exec(op);
        return op.z();
    }

    /**
     * Compute the distance between 2 vectors
     * given a function name. Valid function names:
     * euclidean: euclidean distance
     * cosinedistance: cosine distance
     * cosinesimilarity: cosine similarity
     * manhattan: manhattan distance
     * jaccard: jaccard distance
     * hamming: hamming distance
     * @param function the function to use (default euclidean distance)
     * @param x the first vector
     * @param y the second vector
     * @param result the scratch array to store the result in
     * @return the distance between the 2 vectors given the inputs
     */
    public static double computeDistance(String function, INDArray x, INDArray y, INDArray result) {
        ReduceOp op = (ReduceOp) getOp(function, x, y, result);
        Nd4j.getExecutioner().exec(op);
        return op.z().getDouble(0);
    }

    /**
     * Compute the distance between 2 vectors
     * given a function name. Valid function names:
     * euclidean: euclidean distance
     * cosinedistance: cosine distance
     * cosinesimilarity: cosine similarity
     * manhattan: manhattan distance
     * jaccard: jaccard distance
     * hamming: hamming distance
     * @param function the function to use (default euclidean distance)
     * @param x the first vector
     * @param y the second vector
     * @return the distance between the 2 vectors given the inputs
     */
    public static double computeDistance(String function, INDArray x, INDArray y) {
        return computeDistance(function, x, y, Nd4j.scalar(0.0));
    }

    /**
     * Initialize the tree given the input parameters
     * @param tree the tree to initialize
     * @param from the starting node
     * @param planes the hyper planes to use (vector space for similarity)
     * @param X the input data
     * @param maxSize the max number of indices on a given leaf node
     * @param depth the current depth of the tree
     * @param similarityFunction the similarity function to use
     */
    public static void buildTree(RPTree tree,
                                 RPNode from,
                                 RPHyperPlanes planes,
                                 INDArray X,
                                 int maxSize,
                                 int depth,
                                 String similarityFunction) {
        if(from.getIndices().size() <= maxSize) {
            //slimNode
            slimNode(from);
            return;
        }

        List<Double> distances = new ArrayList<>();
        RPNode left = new RPNode(tree, depth + 1);
        RPNode right = new RPNode(tree, depth + 1);

        if(planes.getWholeHyperPlane() == null || depth >= planes.getWholeHyperPlane().rows()) {
            planes.addRandomHyperPlane();
        }

        INDArray hyperPlane = planes.getHyperPlaneAt(depth);

        // distance of every point on this node to the splitting hyper plane
        for(int i = 0; i < from.getIndices().size(); i++) {
            double dist = computeDistance(similarityFunction, hyperPlane, X.slice(from.getIndices().get(i)));
            distances.add(dist);
        }

        Collections.sort(distances);
        from.setMedian(distances.get(distances.size() / 2));

        // route each point left or right of the median
        for(int i = 0; i < from.getIndices().size(); i++) {
            double dist = computeDistance(similarityFunction, hyperPlane, X.slice(from.getIndices().get(i)));
            if(dist <= from.getMedian()) {
                left.getIndices().add(from.getIndices().get(i));
            }
            else {
                right.getIndices().add(from.getIndices().get(i));
            }
        }

        //failed split
        if(left.getIndices().isEmpty() || right.getIndices().isEmpty()) {
            slimNode(from);
            return;
        }

        from.setLeft(left);
        from.setRight(right);
        slimNode(from);

        buildTree(tree, left, planes, X, maxSize, depth + 1, similarityFunction);
        buildTree(tree, right, planes, X, maxSize, depth + 1, similarityFunction);
    }

    /**
     * Scan for leaves accumulating
     * the nodes in the passed in list
     * @param nodes the nodes so far
     * @param scan the tree to scan
     */
    public static void scanForLeaves(List<RPNode> nodes, RPTree scan) {
        scanForLeaves(nodes, scan.getRoot());
    }

    /**
     * Scan for leaves accumulating
     * the nodes in the passed in list
     * @param nodes the nodes so far
     * @param current the node to scan from
     */
    public static void scanForLeaves(List<RPNode> nodes, RPNode current) {
        if(current.getLeft() == null && current.getRight() == null)
            nodes.add(current);
        if(current.getLeft() != null)
            scanForLeaves(nodes, current.getLeft());
        if(current.getRight() != null)
            scanForLeaves(nodes, current.getRight());
    }

    /**
     * Prune indices from the given node
     * when it's a leaf
     * @param node the node to prune
     */
    public static void slimNode(RPNode node) {
        if(node.getRight() != null && node.getLeft() != null) {
            node.getIndices().clear();
        }

    }

}
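Taken together, the helpers above split a dataset down random hyperplanes and route queries to candidate leaves. A minimal usage sketch follows; the RPTree(dimension, maxSize) constructor is an assumption not shown in this diff, while buildTree, getRoot(), getRpHyperPlanes() and queryAll all appear above:

    // Hypothetical wiring of the static helpers above; RPTree's constructor is assumed.
    INDArray X = Nd4j.rand(1000, 32);                 // 1000 points in 32 dimensions
    RPTree tree = new RPTree(32, 50);                 // assumed ctor: dimension, max leaf size
    RPUtils.buildTree(tree, tree.getRoot(), tree.getRpHyperPlanes(), X, 50, 0, "euclidean");

    INDArray query = Nd4j.rand(1, 32);
    INDArray nearest = RPUtils.queryAll(query, X, Collections.singletonList(tree), 10, "euclidean");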
@@ -1,87 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.sptree;

import org.nd4j.linalg.api.ndarray.INDArray;

import java.io.Serializable;

/**
 * @author Adam Gibson
 */
public class Cell implements Serializable {
    private int dimension;
    private INDArray corner, width;

    public Cell(int dimension) {
        this.dimension = dimension;
    }

    public double corner(int d) {
        return corner.getDouble(d);
    }

    public double width(int d) {
        return width.getDouble(d);
    }

    public void setCorner(int d, double corner) {
        this.corner.putScalar(d, corner);
    }

    public void setWidth(int d, double width) {
        this.width.putScalar(d, width);
    }

    public void setWidth(INDArray width) {
        this.width = width;
    }

    public void setCorner(INDArray corner) {
        this.corner = corner;
    }

    public boolean contains(INDArray point) {
        INDArray cornerMinusWidth = corner.sub(width);
        INDArray cornerPlusWidth = corner.add(width);
        for (int d = 0; d < dimension; d++) {
            double pointD = point.getDouble(d);
            if (cornerMinusWidth.getDouble(d) > pointD)
                return false;
            if (cornerPlusWidth.getDouble(d) < pointD)
                return false;
        }
        return true;
    }

    public INDArray width() {
        return width;
    }

    public INDArray corner() {
        return corner;
    }

}
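Cell models an axis-aligned box by a center (the "corner" field) and a per-dimension half-width; contains() checks center ± width in every dimension. A minimal sketch of the contract, using only the methods shown above:

    Cell cell = new Cell(2);
    cell.setCorner(Nd4j.create(new double[] {0.0, 0.0}));   // box center
    cell.setWidth(Nd4j.create(new double[] {1.0, 1.0}));    // half-width per dimension
    boolean inside = cell.contains(Nd4j.create(new double[] {0.5, -0.5}));   // true
    boolean outside = cell.contains(Nd4j.create(new double[] {2.0, 0.0}));   // false: outside on dimension 0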
@@ -1,95 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.sptree;

import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.factory.Nd4j;

import java.io.Serializable;

@Data
public class DataPoint implements Serializable {
    private int index;
    private INDArray point;
    private long d;
    private String functionName;
    private boolean invert = false;

    public DataPoint(int index, INDArray point, boolean invert) {
        this(index, point, "euclidean");
        this.invert = invert;
    }

    public DataPoint(int index, INDArray point, String functionName, boolean invert) {
        this.index = index;
        this.point = point;
        this.functionName = functionName;
        this.d = point.length();
        this.invert = invert;
    }

    public DataPoint(int index, INDArray point) {
        this(index, point, false);
    }

    public DataPoint(int index, INDArray point, String functionName) {
        this(index, point, functionName, false);
    }

    /**
     * Distance from this point to the given point,
     * using the configured function name (euclidean by default)
     * @param point the point to measure the distance to
     * @return the distance between the two points, negated when invert is set
     */
    public float distance(DataPoint point) {
        switch (functionName) {
            case "euclidean":
                float ret = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this.point, point.point))
                                .getFinalResult().floatValue();
                return invert ? -ret : ret;

            case "cosinesimilarity":
                float ret2 = Nd4j.getExecutioner().execAndReturn(new CosineSimilarity(this.point, point.point))
                                .getFinalResult().floatValue();
                return invert ? -ret2 : ret2;

            case "manhattan":
                float ret3 = Nd4j.getExecutioner().execAndReturn(new ManhattanDistance(this.point, point.point))
                                .getFinalResult().floatValue();
                return invert ? -ret3 : ret3;
            case "dot":
                float dotRet = (float) Nd4j.getBlasWrapper().dot(this.point, point.point);
                return invert ? -dotRet : dotRet;
            default:
                float ret4 = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this.point, point.point))
                                .getFinalResult().floatValue();
                return invert ? -ret4 : ret4;

        }
    }

}
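The invert flag negates the returned value, which lets callers that always minimize treat a similarity (where larger means closer) the same way as a distance. A small sketch using only the constructors above:

    DataPoint a = new DataPoint(0, Nd4j.create(new double[] {1, 0}), "cosinesimilarity", true);
    DataPoint b = new DataPoint(1, Nd4j.create(new double[] {1, 0}));
    float score = a.distance(b);   // -1.0f: similarity 1.0, negated because invert is set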
@@ -1,83 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.sptree;

import java.io.Serializable;

/**
 * @author Adam Gibson
 */
public class HeapItem implements Serializable, Comparable<HeapItem> {
    private int index;
    private double distance;

    public HeapItem(int index, double distance) {
        this.index = index;
        this.distance = distance;
    }

    public int getIndex() {
        return index;
    }

    public void setIndex(int index) {
        this.index = index;
    }

    public double getDistance() {
        return distance;
    }

    public void setDistance(double distance) {
        this.distance = distance;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o)
            return true;
        if (o == null || getClass() != o.getClass())
            return false;

        HeapItem heapItem = (HeapItem) o;

        if (index != heapItem.index)
            return false;
        return Double.compare(heapItem.distance, distance) == 0;

    }

    @Override
    public int hashCode() {
        int result;
        long temp;
        result = index;
        temp = Double.doubleToLongBits(distance);
        result = 31 * result + (int) (temp ^ (temp >>> 32));
        return result;
    }

    @Override
    public int compareTo(HeapItem o) {
        // Double.compare gives the full -1/0/1 contract; the original
        // (distance < o.distance ? 1 : 0) never reported "less than"
        return Double.compare(distance, o.distance);
    }
}
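With the corrected compareTo, HeapItem orders ascending by distance and behaves in any JDK priority queue; the original 1/0 comparator violated the Comparable contract and could silently corrupt heap order. A minimal sketch:

    java.util.PriorityQueue<HeapItem> heap = new java.util.PriorityQueue<>();
    heap.add(new HeapItem(0, 2.5));
    heap.add(new HeapItem(1, 0.7));
    heap.add(new HeapItem(2, 1.3));
    int nearest = heap.poll().getIndex();   // 1: the smallest distance comes out first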
@@ -1,72 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.sptree;

import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.io.Serializable;

@Data
public class HeapObject implements Serializable, Comparable<HeapObject> {
    private int index;
    private INDArray point;
    private double distance;

    public HeapObject(int index, INDArray point, double distance) {
        this.index = index;
        this.point = point;
        this.distance = distance;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o)
            return true;
        if (o == null || getClass() != o.getClass())
            return false;

        HeapObject heapObject = (HeapObject) o;

        if (!point.equals(heapObject.point))
            return false;

        return Double.compare(heapObject.distance, distance) == 0;

    }

    @Override
    public int hashCode() {
        int result;
        long temp;
        result = index;
        temp = Double.doubleToLongBits(distance);
        result = 31 * result + (int) (temp ^ (temp >>> 32));
        return result;
    }

    @Override
    public int compareTo(HeapObject o) {
        // same contract fix as HeapItem: the original 1/0 comparator was broken
        return Double.compare(distance, o.distance);
    }
}
@@ -1,425 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.sptree;

import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import lombok.val;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Set;

/**
 * @author Adam Gibson
 */
public class SpTree implements Serializable {

    public final static String workspaceExternal = "SPTREE_LOOP_EXTERNAL";

    private int D;
    private INDArray data;
    public final static int NODE_RATIO = 8000;
    private int N;
    private int size;
    private int cumSize;
    private Cell boundary;
    private INDArray centerOfMass;
    private SpTree parent;
    private int[] index;
    private int nodeCapacity;
    private int numChildren = 2;
    private boolean isLeaf = true;
    private Collection<INDArray> indices;
    private SpTree[] children;
    private static Logger log = LoggerFactory.getLogger(SpTree.class);
    private String similarityFunction = Distance.EUCLIDEAN.toString();

    public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices,
                    String similarityFunction) {
        init(parent, data, corner, width, indices, similarityFunction);
    }

    public SpTree(INDArray data, Collection<INDArray> indices, String similarityFunction) {
        this.indices = indices;
        this.N = data.rows();
        this.D = data.columns();
        this.similarityFunction = similarityFunction;
        data = data.dup();
        INDArray meanY = data.mean(0);
        INDArray minY = data.min(0);
        INDArray maxY = data.max(0);
        INDArray width = Nd4j.create(data.dataType(), meanY.shape());
        for (int i = 0; i < width.length(); i++) {
            width.putScalar(i, Math.max(maxY.getDouble(i) - meanY.getDouble(i),
                            meanY.getDouble(i) - minY.getDouble(i)) + Nd4j.EPS_THRESHOLD);
        }

        try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            init(null, data, meanY, width, indices, similarityFunction);
            fill(N);
        }
    }

    public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices) {
        this(parent, data, corner, width, indices, "euclidean");
    }

    public SpTree(INDArray data, Collection<INDArray> indices) {
        this(data, indices, "euclidean");
    }

    public SpTree(INDArray data) {
        this(data, new ArrayList<INDArray>());
    }

    public MemoryWorkspace workspace() {
        return null;
    }

    private void init(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices,
                    String similarityFunction) {

        this.parent = parent;
        D = data.columns();
        N = data.rows();
        this.similarityFunction = similarityFunction;
        nodeCapacity = N % NODE_RATIO;
        index = new int[nodeCapacity];
        for (int d = 1; d < this.D; d++)
            numChildren *= 2;
        this.indices = indices;
        isLeaf = true;
        size = 0;
        cumSize = 0;
        children = new SpTree[numChildren];
        this.data = data;
        boundary = new Cell(D);
        boundary.setCorner(corner.dup());
        boundary.setWidth(width.dup());
        centerOfMass = Nd4j.create(data.dataType(), D);
    }

    private boolean insert(int index) {
        /*MemoryWorkspace workspace =
                        workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace()
                                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                                                        workspaceConfigurationExternal,
                                                        workspaceExternal);
        try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ {

            INDArray point = data.slice(index);
            /*boolean contains = false;
            SpTreeCell op = new SpTreeCell(boundary.corner(), boundary.width(), point, N, contains);
            Nd4j.getExecutioner().exec(op);
            op.getOutputArgument(0).getScalar(0);
            if (!contains) return false;*/
            if (!boundary.contains(point))
                return false;

            cumSize++;
            double mult1 = (double) (cumSize - 1) / (double) cumSize;
            double mult2 = 1.0 / (double) cumSize;
            // running mean: shift the center of mass toward the new point
            centerOfMass.muli(mult1);
            centerOfMass.addi(point.mul(mult2));
            // If there is space in this quad tree and it is a leaf, add the object here
            if (isLeaf() && size < nodeCapacity) {
                this.index[size] = index;
                indices.add(point);
                size++;
                return true;
            }

            // duplicate points always stay with this node
            for (int i = 0; i < size; i++) {
                INDArray compPoint = data.slice(this.index[i]);
                if (compPoint.equals(point))
                    return true;
            }

            if (isLeaf())
                subDivide();

            // Find out where the point can be inserted
            for (int i = 0; i < numChildren; i++) {
                if (children[i].insert(index))
                    return true;
            }

            throw new IllegalStateException("Shouldn't reach this state");
        }
    }

    /**
     * Subdivide the node into
     * 2^d children (4 in the two-dimensional case)
     */
    public void subDivide() {
        /*MemoryWorkspace workspace =
                        workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace()
                                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                                                        workspaceConfigurationExternal,
                                                        workspaceExternal);
        try (MemoryWorkspace ws = workspace.notifyScopeEntered()) */{

            INDArray newCorner = Nd4j.create(data.dataType(), D);
            INDArray newWidth = Nd4j.create(data.dataType(), D);
            for (int i = 0; i < numChildren; i++) {
                int div = 1;
                for (int d = 0; d < D; d++) {
                    newWidth.putScalar(d, .5 * boundary.width(d));
                    if ((i / div) % 2 == 1)
                        newCorner.putScalar(d, boundary.corner(d) - .5 * boundary.width(d));
                    else
                        newCorner.putScalar(d, boundary.corner(d) + .5 * boundary.width(d));
                    div *= 2;
                }

                children[i] = new SpTree(this, data, newCorner, newWidth, indices);

            }

            // Move existing points to correct children
            for (int i = 0; i < size; i++) {
                boolean success = false;
                for (int j = 0; j < this.numChildren; j++)
                    if (!success)
                        success = children[j].insert(index[i]);

                index[i] = -1;
            }

            // Empty parent node
            size = 0;
            isLeaf = false;
        }
    }

    /**
     * Compute non edge forces using barnes hut
     * @param pointIndex the index of the point to compute the force for
     * @param theta the Barnes-Hut approximation threshold
     * @param negativeForce the accumulator for the repulsive force
     * @param sumQ the running sum of Q values (normalization term)
     */
    public void computeNonEdgeForces(int pointIndex, double theta, INDArray negativeForce, AtomicDouble sumQ) {
        // Make sure that we spend no time on empty nodes or self-interactions
        INDArray buf = Nd4j.create(data.dataType(), this.D);

        if (cumSize == 0 || (isLeaf() && size == 1 && index[0] == pointIndex))
            return;
        /* MemoryWorkspace workspace =
                        workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace()
                                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                                                        workspaceConfigurationExternal,
                                                        workspaceExternal);
        try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ {

            // Compute distance between point and center-of-mass
            data.slice(pointIndex).subi(centerOfMass, buf);

            double D = Nd4j.getBlasWrapper().dot(buf, buf);
            // Check whether we can use this node as a "summary"
            double maxWidth = boundary.width().maxNumber().doubleValue();
            if (isLeaf() || maxWidth / Math.sqrt(D) < theta) {

                // Compute and add t-SNE force between point and current node
                double Q = 1.0 / (1.0 + D);
                double mult = cumSize * Q;
                sumQ.addAndGet(mult);
                mult *= Q;
                negativeForce.addi(buf.mul(mult));
            } else {

                // Recursively apply Barnes-Hut to children
                for (int i = 0; i < numChildren; i++) {
                    children[i].computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
                }

            }
        }
    }

    /**
     *
     * Compute edge forces using barnes hut
     * @param rowP a vector of row offsets into colP/valP (CSR-style)
     * @param colP the column indices (neighbors) for each edge
     * @param valP the values (input similarities) for each edge
     * @param N the number of elements
     * @param posF the positive force
     */
    public void computeEdgeForces(INDArray rowP, INDArray colP, INDArray valP, int N, INDArray posF) {
        if (!rowP.isVector())
            throw new IllegalArgumentException("RowP must be a vector");

        // Loop over all edges in the graph
        // just execute native op
        Nd4j.exec(new BarnesEdgeForces(rowP, colP, valP, data, N, posF));

        /*
        INDArray buf = Nd4j.create(data.dataType(), this.D);
        double D;
        for (int n = 0; n < N; n++) {
            INDArray slice = data.slice(n);
            for (int i = rowP.getInt(n); i < rowP.getInt(n + 1); i++) {

                // Compute pairwise distance and Q-value
                slice.subi(data.slice(colP.getInt(i)), buf);

                D = 1.0 + Nd4j.getBlasWrapper().dot(buf, buf);
                D = valP.getDouble(i) / D;

                // Sum positive force
                posF.slice(n).addi(buf.muli(D));
            }
        }
        */
    }

    public boolean isLeaf() {
        return isLeaf;
    }

    /**
     * Verifies the structure of the tree (does bounds checking on each node)
     * @return true if the structure of the tree
     * is correct.
     */
    public boolean isCorrect() {
        /*MemoryWorkspace workspace =
                        workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace()
                                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                                                        workspaceConfigurationExternal,
                                                        workspaceExternal);
        try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ {

            for (int n = 0; n < size; n++) {
                INDArray point = data.slice(index[n]);
                if (!boundary.contains(point))
                    return false;
            }
            if (!isLeaf()) {
                boolean correct = true;
                for (int i = 0; i < numChildren; i++)
                    correct = correct && children[i].isCorrect();
                return correct;
            }

            return true;
        }
    }

    /**
     * The depth of the node
     * @return the depth of the node
     */
    public int depth() {
        if (isLeaf())
            return 1;
        int depth = 1;
        int maxChildDepth = 0;
        for (int i = 0; i < numChildren; i++) {
            // use children[i], not children[0], so every subtree is measured
            maxChildDepth = Math.max(maxChildDepth, children[i].depth());
        }

        return depth + maxChildDepth;
    }

    private void fill(int n) {
        if (indices.isEmpty() && parent == null)
            for (int i = 0; i < n; i++) {
                log.trace("Inserted " + i);
                insert(i);
            }
        else
            log.warn("Called fill already");
    }

    public SpTree[] getChildren() {
        return children;
    }

    public int getD() {
        return D;
    }

    public INDArray getCenterOfMass() {
        return centerOfMass;
    }

    public Cell getBoundary() {
        return boundary;
    }

    public int[] getIndex() {
        return index;
    }

    public int getCumSize() {
        return cumSize;
    }

    public void setCumSize(int cumSize) {
        this.cumSize = cumSize;
    }

    public int getNumChildren() {
        return numChildren;
    }

    public void setNumChildren(int numChildren) {
        this.numChildren = numChildren;
    }

}
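A minimal sketch of how the class above is driven for Barnes-Hut t-SNE: the single-argument constructor inserts every row via fill(N), and computeNonEdgeForces accumulates the repulsive force and the normalization term for one point. Only members shown above are used:

    INDArray points = Nd4j.rand(500, 2);
    SpTree tree = new SpTree(points);                  // fill(N) inserts all 500 rows

    INDArray negativeForce = Nd4j.create(points.dataType(), 2);
    AtomicDouble sumQ = new AtomicDouble(0.0);
    tree.computeNonEdgeForces(0, 0.5, negativeForce, sumQ);   // theta = 0.5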
@@ -1,117 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.strategy;

import lombok.*;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition;
import org.deeplearning4j.clustering.condition.ConvergenceCondition;
import org.deeplearning4j.clustering.condition.FixedIterationCountCondition;

import java.io.Serializable;

@AllArgsConstructor(access = AccessLevel.PROTECTED)
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public abstract class BaseClusteringStrategy implements ClusteringStrategy, Serializable {
    @Getter(AccessLevel.PUBLIC)
    @Setter(AccessLevel.PROTECTED)
    protected ClusteringStrategyType type;
    @Getter(AccessLevel.PUBLIC)
    @Setter(AccessLevel.PROTECTED)
    protected Integer initialClusterCount;
    @Getter(AccessLevel.PUBLIC)
    @Setter(AccessLevel.PROTECTED)
    protected ClusteringAlgorithmCondition optimizationPhaseCondition;
    @Getter(AccessLevel.PUBLIC)
    @Setter(AccessLevel.PROTECTED)
    protected ClusteringAlgorithmCondition terminationCondition;
    @Getter(AccessLevel.PUBLIC)
    @Setter(AccessLevel.PROTECTED)
    protected boolean inverse;
    @Getter(AccessLevel.PUBLIC)
    @Setter(AccessLevel.PROTECTED)
    protected Distance distanceFunction;
    @Getter(AccessLevel.PUBLIC)
    @Setter(AccessLevel.PROTECTED)
    protected boolean allowEmptyClusters;

    public BaseClusteringStrategy(ClusteringStrategyType type, Integer initialClusterCount, Distance distanceFunction,
                    boolean allowEmptyClusters, boolean inverse) {
        this.type = type;
        this.initialClusterCount = initialClusterCount;
        this.distanceFunction = distanceFunction;
        this.allowEmptyClusters = allowEmptyClusters;
        this.inverse = inverse;
    }

    public BaseClusteringStrategy(ClusteringStrategyType clusteringStrategyType, int initialClusterCount,
                    Distance distanceFunction, boolean inverse) {
        this(clusteringStrategyType, initialClusterCount, distanceFunction, false, inverse);
    }

    /**
     * Stop clustering once the given iteration count is reached
     * @param maxIterationCount the maximum number of iterations to run
     * @return this strategy, for chaining
     */
    public BaseClusteringStrategy endWhenIterationCountEquals(int maxIterationCount) {
        setTerminationCondition(FixedIterationCountCondition.iterationCountGreaterThan(maxIterationCount));
        return this;
    }

    /**
     * Stop clustering once the point distribution stabilizes
     * @param rate the distribution variation rate below which clustering stops
     * @return this strategy, for chaining
     */
    public BaseClusteringStrategy endWhenDistributionVariationRateLessThan(double rate) {
        setTerminationCondition(ConvergenceCondition.distributionVariationRateLessThan(rate));
        return this;
    }

    /**
     * @return whether distances are inverted (treated as similarities)
     */
    @Override
    public boolean inverseDistanceCalculation() {
        return inverse;
    }

    /**
     * Check whether this strategy is of the given type
     * @param type the strategy type to compare against
     * @return true if the types match
     */
    public boolean isStrategyOfType(ClusteringStrategyType type) {
        return type.equals(this.type);
    }

    /**
     * @return the initial number of clusters
     */
    public Integer getInitialClusterCount() {
        return initialClusterCount;
    }

}
@@ -1,102 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.strategy;

import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition;
import org.deeplearning4j.clustering.iteration.IterationHistory;

/**
 * Configuration contract for a clustering run: initial state, distance
 * function, termination condition and optional optimization phase.
 */
public interface ClusteringStrategy {

    /**
     * @return whether distances are inverted (treated as similarities)
     */
    boolean inverseDistanceCalculation();

    /**
     * @return the type of this strategy
     */
    ClusteringStrategyType getType();

    /**
     * @param type the strategy type to compare against
     * @return true if this strategy is of the given type
     */
    boolean isStrategyOfType(ClusteringStrategyType type);

    /**
     * @return the initial number of clusters
     */
    Integer getInitialClusterCount();

    /**
     * @return the distance function in use
     */
    Distance getDistanceFunction();

    /**
     * @return whether empty clusters are permitted
     */
    boolean isAllowEmptyClusters();

    /**
     * @return the condition that terminates clustering
     */
    ClusteringAlgorithmCondition getTerminationCondition();

    /**
     * @return whether an optimization phase has been configured
     */
    boolean isOptimizationDefined();

    /**
     * @param iterationHistory the history of iterations so far
     * @return whether the optimization phase should run now
     */
    boolean isOptimizationApplicableNow(IterationHistory iterationHistory);

    /**
     * @param maxIterationCount the maximum number of iterations to run
     * @return this strategy, for chaining
     */
    BaseClusteringStrategy endWhenIterationCountEquals(int maxIterationCount);

    /**
     * @param rate the distribution variation rate below which clustering stops
     * @return this strategy, for chaining
     */
    BaseClusteringStrategy endWhenDistributionVariationRateLessThan(double rate);

}
@@ -1,25 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.strategy;

public enum ClusteringStrategyType {
    FIXED_CLUSTER_COUNT, OPTIMIZATION
}
@@ -1,68 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.strategy;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.iteration.IterationHistory;

/**
 * Clustering strategy that keeps a fixed number of clusters and never runs an optimization phase.
 */
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public class FixedClusterCountStrategy extends BaseClusteringStrategy {

    protected FixedClusterCountStrategy(Integer initialClusterCount, Distance distanceFunction,
                    boolean allowEmptyClusters, boolean inverse) {
        super(ClusteringStrategyType.FIXED_CLUSTER_COUNT, initialClusterCount, distanceFunction, allowEmptyClusters,
                        inverse);
    }

    /**
     * Set up a strategy with a fixed number of clusters
     * @param clusterCount the number of clusters to produce
     * @param distanceFunction the distance function to use
     * @param inverse whether to invert the distance (treat it as a similarity)
     * @return the configured strategy
     */
    public static FixedClusterCountStrategy setup(int clusterCount, Distance distanceFunction, boolean inverse) {
        return new FixedClusterCountStrategy(clusterCount, distanceFunction, false, inverse);
    }

    /**
     * @return whether distances are inverted (treated as similarities)
     */
    @Override
    public boolean inverseDistanceCalculation() {
        return inverse;
    }

    public boolean isOptimizationDefined() {
        return false;
    }

    public boolean isOptimizationApplicableNow(IterationHistory iterationHistory) {
        return false;
    }

}
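A minimal configuration sketch using only members shown in this diff; setup fixes the cluster count and the inherited endWhen... methods attach a termination condition:

    ClusteringStrategy strategy = FixedClusterCountStrategy
                    .setup(5, Distance.EUCLIDEAN, false)   // 5 clusters, plain (non-inverted) distance
                    .endWhenIterationCountEquals(100);     // stop after 100 iterations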
@@ -1,82 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.strategy;

import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition;
import org.deeplearning4j.clustering.condition.ConvergenceCondition;
import org.deeplearning4j.clustering.condition.FixedIterationCountCondition;
import org.deeplearning4j.clustering.iteration.IterationHistory;
import org.deeplearning4j.clustering.optimisation.ClusteringOptimization;
import org.deeplearning4j.clustering.optimisation.ClusteringOptimizationType;

public class OptimisationStrategy extends BaseClusteringStrategy {
    public static int defaultIterationCount = 100;

    private ClusteringOptimization clusteringOptimisation;
    private ClusteringAlgorithmCondition clusteringOptimisationApplicationCondition;

    protected OptimisationStrategy() {
        super();
    }

    protected OptimisationStrategy(int initialClusterCount, Distance distanceFunction) {
        super(ClusteringStrategyType.OPTIMIZATION, initialClusterCount, distanceFunction, false);
    }

    public static OptimisationStrategy setup(int initialClusterCount, Distance distanceFunction) {
        return new OptimisationStrategy(initialClusterCount, distanceFunction);
    }

    public OptimisationStrategy optimize(ClusteringOptimizationType type, double value) {
        clusteringOptimisation = new ClusteringOptimization(type, value);
        return this;
    }

    public OptimisationStrategy optimizeWhenIterationCountMultipleOf(int value) {
        // note: despite the method name, this reuses the fixed-iteration-count condition
        clusteringOptimisationApplicationCondition = FixedIterationCountCondition.iterationCountGreaterThan(value);
        return this;
    }

    public OptimisationStrategy optimizeWhenPointDistributionVariationRateLessThan(double rate) {
        clusteringOptimisationApplicationCondition = ConvergenceCondition.distributionVariationRateLessThan(rate);
        return this;
    }

    public double getClusteringOptimizationValue() {
        return clusteringOptimisation.getValue();
    }

    public boolean isClusteringOptimizationType(ClusteringOptimizationType type) {
        return clusteringOptimisation != null && clusteringOptimisation.getType().equals(type);
    }

    public boolean isOptimizationDefined() {
        return clusteringOptimisation != null;
    }

    public boolean isOptimizationApplicableNow(IterationHistory iterationHistory) {
        return clusteringOptimisationApplicationCondition != null
                        && clusteringOptimisationApplicationCondition.isSatisfied(iterationHistory);
    }

}
File diff suppressed because it is too large
@ -1,74 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.util;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.concurrent.*;

public class MultiThreadUtils {

    private static Logger log = LoggerFactory.getLogger(MultiThreadUtils.class);

    private static ExecutorService instance;

    private MultiThreadUtils() {}

    public static synchronized ExecutorService newExecutorService() {
        int nThreads = Runtime.getRuntime().availableProcessors();
        return new ThreadPoolExecutor(nThreads, nThreads, 60L, TimeUnit.SECONDS, new LinkedTransferQueue<Runnable>(),
                        new ThreadFactory() {
                            @Override
                            public Thread newThread(Runnable r) {
                                Thread t = Executors.defaultThreadFactory().newThread(r);
                                t.setDaemon(true);
                                return t;
                            }
                        });
    }

    public static void parallelTasks(final List<Runnable> tasks, ExecutorService executorService) {
        int tasksCount = tasks.size();
        final CountDownLatch latch = new CountDownLatch(tasksCount);
        for (int i = 0; i < tasksCount; i++) {
            final int taskIdx = i;
            executorService.execute(new Runnable() {
                public void run() {
                    try {
                        tasks.get(taskIdx).run();
                    } catch (Throwable e) {
                        log.info("Unchecked exception thrown by task", e);
                    } finally {
                        latch.countDown();
                    }
                }
            });
        }

        try {
            latch.await();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
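For context on what this removal drops: MultiThreadUtils.parallelTasks was a blocking fork-join over plain Runnables. A minimal usage sketch, written as a fragment (the task bodies and count are illustrative, not taken from the codebase):

    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.ExecutorService;

    ExecutorService pool = MultiThreadUtils.newExecutorService();  // daemon threads, one per core
    List<Runnable> tasks = new ArrayList<>();
    for (int i = 0; i < 4; i++) {
        final int id = i;
        tasks.add(() -> System.out.println("task " + id));         // illustrative workload
    }
    MultiThreadUtils.parallelTasks(tasks, pool);                   // blocks until all tasks finish
    pool.shutdown();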
@ -1,61 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.util;

import java.util.Collection;
import java.util.HashSet;
import java.util.Set;

public class SetUtils {

    private SetUtils() {}

    // Set specific operations

    public static <T> Set<T> intersection(Collection<T> parentCollection, Collection<T> removeFromCollection) {
        Set<T> results = new HashSet<>(parentCollection);
        results.retainAll(removeFromCollection);
        return results;
    }

    public static <T> boolean intersectionP(Set<? extends T> s1, Set<? extends T> s2) {
        for (T elt : s1) {
            if (s2.contains(elt))
                return true;
        }
        return false;
    }

    public static <T> Set<T> union(Set<? extends T> s1, Set<? extends T> s2) {
        Set<T> s3 = new HashSet<>(s1);
        s3.addAll(s2);
        return s3;
    }

    /** Returns s1 \ s2 */
    public static <T> Set<T> difference(Collection<? extends T> s1, Collection<? extends T> s2) {
        Set<T> s3 = new HashSet<>(s1);
        s3.removeAll(s2);
        return s3;
    }
}
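The removed SetUtils methods were thin wrappers over java.util.Set; for reference, their semantics in one short sketch (the values are illustrative):

    Set<Integer> a = new HashSet<>(Arrays.asList(1, 2, 3));
    Set<Integer> b = new HashSet<>(Arrays.asList(3, 4));
    SetUtils.union(a, b);         // {1, 2, 3, 4}
    SetUtils.intersection(a, b);  // {3}
    SetUtils.difference(a, b);    // {1, 2}, i.e. a \ b
    SetUtils.intersectionP(a, b); // true: the sets overlap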
@ -1,633 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.vptree;

import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.sptree.HeapObject;
import org.deeplearning4j.clustering.util.MathUtils;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce3.*;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;

import java.io.Serializable;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;

@Slf4j
@Builder
@AllArgsConstructor
public class VPTree implements Serializable {
    private static final long serialVersionUID = 1L;

    public static final String EUCLIDEAN = "euclidean";
    private double tau;
    @Getter
    @Setter
    private INDArray items;
    private List<INDArray> itemsList;
    private Node root;
    private String similarityFunction;
    @Getter
    private boolean invert = false;
    private transient ExecutorService executorService;
    @Getter
    private int workers = 1;
    private AtomicInteger size = new AtomicInteger(0);

    private transient ThreadLocal<INDArray> scalars = new ThreadLocal<>();

    private WorkspaceConfiguration workspaceConfiguration;

    protected VPTree() {
        // no-arg constructor for serialization only
        scalars = new ThreadLocal<>();
    }

    /**
     *
     * @param points
     * @param invert
     */
    public VPTree(INDArray points, boolean invert) {
        this(points, "euclidean", 1, invert);
    }

    /**
     *
     * @param points
     * @param invert
     * @param workers number of parallel workers for tree building (increases memory requirements!)
     */
    public VPTree(INDArray points, boolean invert, int workers) {
        this(points, "euclidean", workers, invert);
    }

    /**
     *
     * @param items the items to use
     * @param similarityFunction the similarity function to use
     * @param invert whether to invert the distance (similarity functions have different min/max objectives)
     */
    public VPTree(INDArray items, String similarityFunction, boolean invert) {
        this.similarityFunction = similarityFunction;
        this.invert = invert;
        this.items = items;
        root = buildFromPoints(items);
        workers = 1;
    }

    /**
     *
     * @param items the items to use
     * @param similarityFunction the similarity function to use
     * @param workers number of parallel workers for tree building (increases memory requirements!)
     * @param invert whether to invert the metric (different optimization objective)
     */
    public VPTree(List<DataPoint> items, String similarityFunction, int workers, boolean invert) {
        this.workers = workers;

        val list = new INDArray[items.size()];

        // build list of INDArrays first
        for (int i = 0; i < items.size(); i++)
            list[i] = items.get(i).getPoint();
        //this.items.putRow(i, items.get(i).getPoint());

        // just stack them out with concat :)
        this.items = Nd4j.pile(list);

        this.invert = invert;
        this.similarityFunction = similarityFunction;
        root = buildFromPoints(this.items);
    }

    /**
     *
     * @param items
     * @param similarityFunction
     */
    public VPTree(INDArray items, String similarityFunction) {
        this(items, similarityFunction, 1, false);
    }

    /**
     *
     * @param items
     * @param similarityFunction
     * @param workers number of parallel workers for tree building (increases memory requirements!)
     * @param invert
     */
    public VPTree(INDArray items, String similarityFunction, int workers, boolean invert) {
        this.similarityFunction = similarityFunction;
        this.invert = invert;
        this.items = items;

        this.workers = workers;
        root = buildFromPoints(items);
    }

    /**
     *
     * @param items
     * @param similarityFunction
     */
    public VPTree(List<DataPoint> items, String similarityFunction) {
        this(items, similarityFunction, 1, false);
    }

    /**
     *
     * @param items
     */
    public VPTree(INDArray items) {
        this(items, EUCLIDEAN);
    }

    /**
     *
     * @param items
     */
    public VPTree(List<DataPoint> items) {
        this(items, EUCLIDEAN);
    }

    /**
     * Create an ndarray from the datapoints
     * @param data
     * @return
     */
    public static INDArray buildFromData(List<DataPoint> data) {
        INDArray ret = Nd4j.create(data.size(), data.get(0).getD());
        for (int i = 0; i < ret.slices(); i++)
            ret.putSlice(i, data.get(i).getPoint());
        return ret;
    }

    /**
     *
     * @param basePoint
     * @param distancesArr
     */
    public void calcDistancesRelativeTo(INDArray items, INDArray basePoint, INDArray distancesArr) {
        switch (similarityFunction) {
            case "euclidean":
                Nd4j.getExecutioner().exec(new EuclideanDistance(items, basePoint, distancesArr, true, -1));
                break;
            case "cosinedistance":
                Nd4j.getExecutioner().exec(new CosineDistance(items, basePoint, distancesArr, true, -1));
                break;
            case "cosinesimilarity":
                Nd4j.getExecutioner().exec(new CosineSimilarity(items, basePoint, distancesArr, true, -1));
                break;
            case "manhattan":
                Nd4j.getExecutioner().exec(new ManhattanDistance(items, basePoint, distancesArr, true, -1));
                break;
            case "dot":
                Nd4j.getExecutioner().exec(new Dot(items, basePoint, distancesArr, -1));
                break;
            case "jaccard":
                Nd4j.getExecutioner().exec(new JaccardDistance(items, basePoint, distancesArr, true, -1));
                break;
            case "hamming":
                Nd4j.getExecutioner().exec(new HammingDistance(items, basePoint, distancesArr, true, -1));
                break;
            default:
                Nd4j.getExecutioner().exec(new EuclideanDistance(items, basePoint, distancesArr, true, -1));
                break;
        }

        if (invert)
            distancesArr.negi();
    }

    public void calcDistancesRelativeTo(INDArray basePoint, INDArray distancesArr) {
        calcDistancesRelativeTo(items, basePoint, distancesArr);
    }

    /**
     * Euclidean distance
     * @return the distance between the two points
     */
    public double distance(INDArray arr1, INDArray arr2) {
        if (scalars == null)
            scalars = new ThreadLocal<>();

        if (scalars.get() == null)
            scalars.set(Nd4j.scalar(arr1.dataType(), 0.0));

        switch (similarityFunction) {
            case "jaccard":
                double ret7 = Nd4j.getExecutioner()
                                .execAndReturn(new JaccardDistance(arr1, arr2, scalars.get()))
                                .getFinalResult().doubleValue();
                return invert ? -ret7 : ret7;
            case "hamming":
                double ret8 = Nd4j.getExecutioner()
                                .execAndReturn(new HammingDistance(arr1, arr2, scalars.get()))
                                .getFinalResult().doubleValue();
                return invert ? -ret8 : ret8;
            case "euclidean":
                double ret = Nd4j.getExecutioner()
                                .execAndReturn(new EuclideanDistance(arr1, arr2, scalars.get()))
                                .getFinalResult().doubleValue();
                return invert ? -ret : ret;
            case "cosinesimilarity":
                double ret2 = Nd4j.getExecutioner()
                                .execAndReturn(new CosineSimilarity(arr1, arr2, scalars.get()))
                                .getFinalResult().doubleValue();
                return invert ? -ret2 : ret2;
            case "cosinedistance":
                double ret6 = Nd4j.getExecutioner()
                                .execAndReturn(new CosineDistance(arr1, arr2, scalars.get()))
                                .getFinalResult().doubleValue();
                return invert ? -ret6 : ret6;
            case "manhattan":
                double ret3 = Nd4j.getExecutioner()
                                .execAndReturn(new ManhattanDistance(arr1, arr2, scalars.get()))
                                .getFinalResult().doubleValue();
                return invert ? -ret3 : ret3;
            case "dot":
                double dotRet = Nd4j.getBlasWrapper().dot(arr1, arr2);
                return invert ? -dotRet : dotRet;
            default:
                double ret4 = Nd4j.getExecutioner()
                                .execAndReturn(new EuclideanDistance(arr1, arr2, scalars.get()))
                                .getFinalResult().doubleValue();
                return invert ? -ret4 : ret4;
        }
    }

    protected class NodeBuilder implements Callable<Node> {
        protected List<INDArray> list;
        protected List<Integer> indices;

        public NodeBuilder(List<INDArray> list, List<Integer> indices) {
            this.list = list;
            this.indices = indices;
        }

        @Override
        public Node call() throws Exception {
            return buildFromPoints(list, indices);
        }
    }

    private Node buildFromPoints(List<INDArray> points, List<Integer> indices) {
        Node ret = new Node(0, 0);

        // nothing to sort here
        if (points.size() == 1) {
            ret.point = points.get(0);
            ret.index = indices.get(0);
            return ret;
        }

        // opening workspace, and creating it if that's the first call
        /* MemoryWorkspace workspace =
                        Nd4j.getWorkspaceManager().getAndActivateWorkspace(workspaceConfiguration, "VPTREE_WORKSPACE");*/

        INDArray items = Nd4j.vstack(points);
        int randomPoint = MathUtils.randomNumberBetween(0, items.rows() - 1, Nd4j.getRandom());
        INDArray basePoint = points.get(randomPoint); //items.getRow(randomPoint);
        ret.point = basePoint;
        ret.index = indices.get(randomPoint);
        INDArray distancesArr = Nd4j.create(items.rows(), 1);

        calcDistancesRelativeTo(items, basePoint, distancesArr);

        double medianDistance = distancesArr.medianNumber().doubleValue();

        ret.threshold = (float) medianDistance;

        List<INDArray> leftPoints = new ArrayList<>();
        List<Integer> leftIndices = new ArrayList<>();
        List<INDArray> rightPoints = new ArrayList<>();
        List<Integer> rightIndices = new ArrayList<>();

        for (int i = 0; i < distancesArr.length(); i++) {
            if (i == randomPoint)
                continue;

            if (distancesArr.getDouble(i) < medianDistance) {
                leftPoints.add(points.get(i));
                leftIndices.add(indices.get(i));
            } else {
                rightPoints.add(points.get(i));
                rightIndices.add(indices.get(i));
            }
        }

        // closing workspace
        //workspace.notifyScopeLeft();
        //log.info("Thread: {}; Workspace size: {} MB; ConstantCache: {}; ShapeCache: {}; TADCache: {}", Thread.currentThread().getId(), (int) (workspace.getCurrentSize() / 1024 / 1024), Nd4j.getConstantHandler().getCachedBytes(), Nd4j.getShapeInfoProvider().getCachedBytes(), Nd4j.getExecutioner().getTADManager().getCachedBytes());

        if (workers > 1) {
            if (!leftPoints.isEmpty())
                ret.futureLeft = executorService.submit(new NodeBuilder(leftPoints, leftIndices)); // = buildFromPoints(leftPoints);

            if (!rightPoints.isEmpty())
                ret.futureRight = executorService.submit(new NodeBuilder(rightPoints, rightIndices));
        } else {
            if (!leftPoints.isEmpty())
                ret.left = buildFromPoints(leftPoints, leftIndices);

            if (!rightPoints.isEmpty())
                ret.right = buildFromPoints(rightPoints, rightIndices);
        }

        return ret;
    }

    private Node buildFromPoints(INDArray items) {
        if (executorService == null && items == this.items && workers > 1) {
            final val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();

            executorService = Executors.newFixedThreadPool(workers, new ThreadFactory() {
                @Override
                public Thread newThread(final Runnable r) {
                    Thread t = new Thread(new Runnable() {

                        @Override
                        public void run() {
                            Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
                            r.run();
                        }
                    });

                    t.setDaemon(true);
                    t.setName("VPTree thread");

                    return t;
                }
            });
        }

        final Node ret = new Node(0, 0);
        size.incrementAndGet();

        /*workspaceConfiguration = WorkspaceConfiguration.builder().cyclesBeforeInitialization(1)
                        .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.FIRST_LOOP)
                        .policyMirroring(MirroringPolicy.FULL).policyReset(ResetPolicy.BLOCK_LEFT)
                        .policySpill(SpillPolicy.REALLOCATE).build();

        // opening workspace
        MemoryWorkspace workspace =
                        Nd4j.getWorkspaceManager().getAndActivateWorkspace(workspaceConfiguration, "VPTREE_WORKSPACE");*/

        int randomPoint = MathUtils.randomNumberBetween(0, items.rows() - 1, Nd4j.getRandom());
        INDArray basePoint = items.getRow(randomPoint, true);
        INDArray distancesArr = Nd4j.create(items.rows(), 1);
        ret.point = basePoint;
        ret.index = randomPoint;

        calcDistancesRelativeTo(items, basePoint, distancesArr);

        double medianDistance = distancesArr.medianNumber().doubleValue();

        ret.threshold = (float) medianDistance;

        List<INDArray> leftPoints = new ArrayList<>();
        List<Integer> leftIndices = new ArrayList<>();
        List<INDArray> rightPoints = new ArrayList<>();
        List<Integer> rightIndices = new ArrayList<>();

        for (int i = 0; i < distancesArr.length(); i++) {
            if (i == randomPoint)
                continue;

            if (distancesArr.getDouble(i) < medianDistance) {
                leftPoints.add(items.getRow(i, true));
                leftIndices.add(i);
            } else {
                rightPoints.add(items.getRow(i, true));
                rightIndices.add(i);
            }
        }

        // closing workspace
        //workspace.notifyScopeLeft();
        //workspace.destroyWorkspace(true);

        if (!leftPoints.isEmpty())
            ret.left = buildFromPoints(leftPoints, leftIndices);

        if (!rightPoints.isEmpty())
            ret.right = buildFromPoints(rightPoints, rightIndices);

        // destroy once again
        //workspace.destroyWorkspace(true);

        if (ret.left != null)
            ret.left.fetchFutures();

        if (ret.right != null)
            ret.right.fetchFutures();

        if (executorService != null)
            executorService.shutdown();

        return ret;
    }

    public void search(@NonNull INDArray target, int k, List<DataPoint> results, List<Double> distances) {
        search(target, k, results, distances, true);
    }

    public void search(@NonNull INDArray target, int k, List<DataPoint> results, List<Double> distances,
                    boolean filterEqual) {
        search(target, k, results, distances, filterEqual, false);
    }

    /**
     *
     * @param target
     * @param k
     * @param results
     * @param distances
     */
    public void search(@NonNull INDArray target, int k, List<DataPoint> results, List<Double> distances,
                    boolean filterEqual, boolean dropEdge) {
        if (items != null)
            if (!target.isVectorOrScalar() || target.columns() != items.columns() || target.rows() > 1)
                throw new ND4JIllegalStateException("Target for search should have shape of [" + 1 + ", "
                                + items.columns() + "] but got " + Arrays.toString(target.shape()) + " instead");

        k = Math.min(k, items.rows());
        results.clear();
        distances.clear();

        PriorityQueue<HeapObject> pq = new PriorityQueue<>(items.rows(), new HeapObjectComparator());

        search(root, target, k + (filterEqual ? 2 : 1), pq, Double.MAX_VALUE);

        while (!pq.isEmpty()) {
            HeapObject ho = pq.peek();
            results.add(new DataPoint(ho.getIndex(), ho.getPoint()));
            distances.add(ho.getDistance());
            pq.poll();
        }

        Collections.reverse(results);
        Collections.reverse(distances);

        if (dropEdge || results.size() > k) {
            if (filterEqual && distances.get(0) == 0.0) {
                results.remove(0);
                distances.remove(0);
            }

            while (results.size() > k) {
                results.remove(results.size() - 1);
                distances.remove(distances.size() - 1);
            }
        }
    }

    /**
     *
     * @param node
     * @param target
     * @param k
     * @param pq
     */
    public void search(Node node, INDArray target, int k, PriorityQueue<HeapObject> pq, double cTau) {

        if (node == null)
            return;

        double tau = cTau;

        INDArray get = node.getPoint(); //items.getRow(node.getIndex());
        double distance = distance(get, target);
        if (distance < tau) {
            if (pq.size() == k)
                pq.poll();

            pq.add(new HeapObject(node.getIndex(), node.getPoint(), distance));
            if (pq.size() == k)
                tau = pq.peek().getDistance();
        }

        Node left = node.getLeft();
        Node right = node.getRight();

        if (left == null && right == null)
            return;

        if (distance < node.getThreshold()) {
            if (distance - tau < node.getThreshold()) { // if there can still be neighbors inside the ball, recursively search left child first
                search(left, target, k, pq, tau);
            }

            if (distance + tau >= node.getThreshold()) { // if there can still be neighbors outside the ball, recursively search right child
                search(right, target, k, pq, tau);
            }

        } else {
            if (distance + tau >= node.getThreshold()) { // if there can still be neighbors outside the ball, recursively search right child first
                search(right, target, k, pq, tau);
            }

            if (distance - tau < node.getThreshold()) { // if there can still be neighbors inside the ball, recursively search left child
                search(left, target, k, pq, tau);
            }
        }

    }

    protected class HeapObjectComparator implements Comparator<HeapObject> {

        @Override
        public int compare(HeapObject o1, HeapObject o2) {
            return Double.compare(o2.getDistance(), o1.getDistance());
        }
    }

    @Data
    public static class Node implements Serializable {
        private static final long serialVersionUID = 2L;

        private int index;
        private float threshold;
        private Node left, right;
        private INDArray point;
        protected transient Future<Node> futureLeft;
        protected transient Future<Node> futureRight;

        public Node(int index, float threshold) {
            this.index = index;
            this.threshold = threshold;
        }

        public void fetchFutures() {
            try {
                if (futureLeft != null) {
                    /*while (!futureLeft.isDone())
                        Thread.sleep(100);*/

                    left = futureLeft.get();
                }

                if (futureRight != null) {
                    /*while (!futureRight.isDone())
                        Thread.sleep(100);*/

                    right = futureRight.get();
                }

                if (left != null)
                    left.fetchFutures();

                if (right != null)
                    right.fetchFutures();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }

}
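For reference, a minimal usage sketch of the removed VPTree, using only the constructors and search signature shown above (the data shape and k are illustrative, not taken from the codebase):

    INDArray points = Nd4j.rand(DataType.FLOAT, 100, 8);        // 100 illustrative points in 8 dimensions
    VPTree tree = new VPTree(points, "euclidean", 1, false);    // one worker, metric not inverted
    List<DataPoint> results = new ArrayList<>();
    List<Double> distances = new ArrayList<>();
    tree.search(points.getRow(0, true), 5, results, distances); // 5 nearest neighbours of the first point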
@ -1,79 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.vptree;

import lombok.Getter;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ArrayList;
import java.util.List;

public class VPTreeFillSearch {
    private VPTree vpTree;
    private int k;
    @Getter
    private List<DataPoint> results;
    @Getter
    private List<Double> distances;
    private INDArray target;

    public VPTreeFillSearch(VPTree vpTree, int k, INDArray target) {
        this.vpTree = vpTree;
        this.k = k;
        this.target = target;
    }

    public void search() {
        results = new ArrayList<>();
        distances = new ArrayList<>();
        //initial search
        //vpTree.search(target, k, results, distances);

        //fill until there are k results,
        //by going down the sorted list
        //        if(results.size() < k) {
        INDArray distancesArr = Nd4j.create(vpTree.getItems().rows(), 1);
        vpTree.calcDistancesRelativeTo(target, distancesArr);
        INDArray[] sortWithIndices = Nd4j.sortWithIndices(distancesArr, 0, !vpTree.isInvert());
        results.clear();
        distances.clear();
        if (vpTree.getItems().isVector()) {
            for (int i = 0; i < k; i++) {
                int idx = sortWithIndices[0].getInt(i);
                results.add(new DataPoint(idx, Nd4j.scalar(vpTree.getItems().getDouble(idx))));
                distances.add(sortWithIndices[1].getDouble(idx));
            }
        } else {
            for (int i = 0; i < k; i++) {
                int idx = sortWithIndices[0].getInt(i);
                results.add(new DataPoint(idx, vpTree.getItems().getRow(idx)));
                //distances.add(sortWithIndices[1].getDouble(idx));
                distances.add(sortWithIndices[1].getDouble(i));
            }
        }

    }

}
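VPTreeFillSearch, also removed here, bypassed the tree and ranked all points by brute-force distance, guaranteeing exactly k results. A sketch under the same illustrative assumptions as the VPTree example above:

    VPTreeFillSearch fillSearch = new VPTreeFillSearch(tree, 5, points.getRow(0, true));
    fillSearch.search();
    List<DataPoint> neighbours = fillSearch.getResults();        // always exactly k = 5 entries
    List<Double> neighbourDistances = fillSearch.getDistances();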
@ -1,21 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.vptree;
@ -1,46 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.cluster;

import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ArrayList;
import java.util.List;

public class ClusterSetTest {
    @Test
    public void testGetMostPopulatedClusters() {
        ClusterSet clusterSet = new ClusterSet(false);
        List<Cluster> clusters = new ArrayList<>();
        for (int i = 0; i < 5; i++) {
            Cluster cluster = new Cluster();
            cluster.setPoints(Point.toPoints(Nd4j.randn(i + 1, 5)));
            clusters.add(cluster);
        }
        clusterSet.setClusters(clusters);
        List<Cluster> mostPopulatedClusters = clusterSet.getMostPopulatedClusters(5);
        for (int i = 0; i < 5; i++) {
            Assert.assertEquals(5 - i, mostPopulatedClusters.get(i).getPoints().size());
        }
    }
}
@ -1,422 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.kdtree;

import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.joda.time.Duration;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import org.nd4j.shade.guava.base.Stopwatch;
import org.nd4j.shade.guava.primitives.Doubles;
import org.nd4j.shade.guava.primitives.Floats;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public class KDTreeTest extends BaseDL4JTest {

    @Override
    public long getTimeoutMilliseconds() {
        return 120000L;
    }

    private KDTree kdTree;

    @BeforeClass
    public static void beforeClass() {
        Nd4j.setDataType(DataType.FLOAT);
    }

    @Before
    public void setUp() {
        kdTree = new KDTree(2);
        float[] data = new float[]{7, 2};
        kdTree.insert(Nd4j.createFromArray(data));
        data = new float[]{5, 4};
        kdTree.insert(Nd4j.createFromArray(data));
        data = new float[]{2, 3};
        kdTree.insert(Nd4j.createFromArray(data));
        data = new float[]{4, 7};
        kdTree.insert(Nd4j.createFromArray(data));
        data = new float[]{9, 6};
        kdTree.insert(Nd4j.createFromArray(data));
        data = new float[]{8, 1};
        kdTree.insert(Nd4j.createFromArray(data));
    }

    @Test
    public void testTree() {
        KDTree tree = new KDTree(2);
        INDArray half = Nd4j.create(new double[] {0.5, 0.5}, new long[]{1, 2}).castTo(DataType.FLOAT);
        INDArray one = Nd4j.create(new double[] {1, 1}, new long[]{1, 2}).castTo(DataType.FLOAT);
        tree.insert(half);
        tree.insert(one);
        Pair<Double, INDArray> pair = tree.nn(Nd4j.create(new double[] {0.5, 0.5}, new long[]{1, 2}).castTo(DataType.FLOAT));
        assertEquals(half, pair.getValue());
    }

    @Test
    public void testInsert() {
        int elements = 10;
        List<Double> digits = Arrays.asList(1.0, 0.0, 2.0, 3.0);

        KDTree kdTree = new KDTree(digits.size());
        List<List<Double>> lists = new ArrayList<>();
        for (int i = 0; i < elements; i++) {
            List<Double> thisList = new ArrayList<>(digits.size());
            for (int k = 0; k < digits.size(); k++) {
                thisList.add(digits.get(k) + i);
            }
            lists.add(thisList);
        }

        for (int i = 0; i < elements; i++) {
            double[] features = Doubles.toArray(lists.get(i));
            INDArray ind = Nd4j.create(features, new long[]{1, features.length}, DataType.FLOAT);
            kdTree.insert(ind);
            assertEquals(i + 1, kdTree.size());
        }
    }

    @Test
    public void testDelete() {
        int elements = 10;
        List<Double> digits = Arrays.asList(1.0, 0.0, 2.0, 3.0);

        KDTree kdTree = new KDTree(digits.size());
        List<List<Double>> lists = new ArrayList<>();
        for (int i = 0; i < elements; i++) {
            List<Double> thisList = new ArrayList<>(digits.size());
            for (int k = 0; k < digits.size(); k++) {
                thisList.add(digits.get(k) + i);
            }
            lists.add(thisList);
        }

        INDArray toDelete = Nd4j.empty(DataType.DOUBLE),
                        leafToDelete = Nd4j.empty(DataType.DOUBLE);
        for (int i = 0; i < elements; i++) {
            double[] features = Doubles.toArray(lists.get(i));
            INDArray ind = Nd4j.create(features, new long[]{1, features.length}, DataType.FLOAT);
            if (i == 1)
                toDelete = ind;
            if (i == elements - 1) {
                leafToDelete = ind;
            }
            kdTree.insert(ind);
            assertEquals(i + 1, kdTree.size());
        }

        kdTree.delete(toDelete);
        assertEquals(9, kdTree.size());
        kdTree.delete(leafToDelete);
        assertEquals(8, kdTree.size());
    }

    @Test
    public void testNN() {
        int n = 10;

        // make a KD-tree of dimension {#n}
        KDTree kdTree = new KDTree(n);
        for (int i = -1; i < n; i++) {
            // Insert a unit vector along each dimension
            List<Double> vec = new ArrayList<>(n);
            // i = -1 ensures the origin is in the tree
            for (int k = 0; k < n; k++) {
                vec.add((k == i) ? 1.0 : 0.0);
            }
            INDArray indVec = Nd4j.create(Doubles.toArray(vec), new long[]{1, vec.size()}, DataType.FLOAT);
            kdTree.insert(indVec);
        }
        Random rand = new Random();

        // random point in the hypercube
        List<Double> pt = new ArrayList<>(n);
        for (int k = 0; k < n; k++) {
            pt.add(rand.nextDouble());
        }
        Pair<Double, INDArray> result = kdTree.nn(Nd4j.create(Doubles.toArray(pt), new long[]{1, pt.size()}, DataType.FLOAT));

        // always true for points in the unit hypercube
        assertTrue(result.getKey() < Double.MAX_VALUE);

    }

    @Test
    public void testKNN() {
        int dimensions = 512;
        int vectorsNo = isIntegrationTests() ? 50000 : 1000;
        // make a KD-tree of dimension {#dimensions}
        Stopwatch stopwatch = Stopwatch.createStarted();
        KDTree kdTree = new KDTree(dimensions);
        for (int i = -1; i < vectorsNo; i++) {
            // insert a random vector for each entry
            INDArray indVec = Nd4j.rand(DataType.FLOAT, 1, dimensions);
            kdTree.insert(indVec);
        }
        stopwatch.stop();
        System.out.println("Time elapsed for " + kdTree.size() + " nodes construction is " + stopwatch.elapsed(SECONDS));

        Random rand = new Random();
        // random point in the hypercube
        List<Double> pt = new ArrayList<>(dimensions);
        for (int k = 0; k < dimensions; k++) {
            pt.add(rand.nextFloat() * 10.0);
        }
        stopwatch.reset();
        stopwatch.start();
        List<Pair<Float, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Floats.toArray(pt))), 20.0f);
        stopwatch.stop();
        System.out.println("Time elapsed for Search is " + stopwatch.elapsed(MILLISECONDS));
    }

    @Test
    public void testKNN_Simple() {
        int n = 2;
        KDTree kdTree = new KDTree(n);

        float[] data = new float[]{3, 3};
        kdTree.insert(Nd4j.createFromArray(data));
        data = new float[]{1, 1};
        kdTree.insert(Nd4j.createFromArray(data));
        data = new float[]{2, 2};
        kdTree.insert(Nd4j.createFromArray(data));

        data = new float[]{0, 0};
        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 4.5f);

        assertEquals(1.0, result.get(0).getSecond().getDouble(0), 1e-5);
        assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5);

        assertEquals(2.0, result.get(1).getSecond().getDouble(0), 1e-5);
        assertEquals(2.0, result.get(1).getSecond().getDouble(1), 1e-5);

        assertEquals(3.0, result.get(2).getSecond().getDouble(0), 1e-5);
        assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5);
    }

    @Test
    public void testKNN_1() {

        assertEquals(6, kdTree.size());

        float[] data = new float[]{8, 1};
        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
        assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5);
        assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5);
        assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5);
        assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5);
        assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5);
        assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5);
        assertEquals(9.0, result.get(3).getSecond().getFloat(0), 1e-5);
        assertEquals(6.0, result.get(3).getSecond().getFloat(1), 1e-5);
        assertEquals(2.0, result.get(4).getSecond().getFloat(0), 1e-5);
        assertEquals(3.0, result.get(4).getSecond().getFloat(1), 1e-5);
        assertEquals(4.0, result.get(5).getSecond().getFloat(0), 1e-5);
        assertEquals(7.0, result.get(5).getSecond().getFloat(1), 1e-5);
    }

    @Test
    public void testKNN_2() {
        float[] data = new float[]{8, 1};
        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
        assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5);
        assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5);
        assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5);
        assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5);
        assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5);
        assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5);
    }

    @Test
    public void testKNN_3() {

        float[] data = new float[]{2, 3};
        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
        assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
        assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
        assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
        assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
        assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
        assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
        assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5);
        assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5);
        assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5);
        assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5);
        assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5);
        assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5);
    }

    @Test
    public void testKNN_4() {
        float[] data = new float[]{2, 3};
        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
        assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
        assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
        assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
        assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
        assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
        assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
    }

    @Test
    public void testKNN_5() {
        float[] data = new float[]{2, 3};
        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f);
        assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
        assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
        assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
        assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
        assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
        assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
        assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5);
        assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5);
        assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5);
        assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5);
        assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5);
        assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5);
    }

    @Test
    public void test_KNN_6() {
        float[] data = new float[]{4, 6};
        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
        assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
        assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
        assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
        assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5);
        assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5);
        assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5);
        assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5);
        assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5);
        assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5);
        assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5);
        assertEquals(8.0, result.get(5).getSecond().getDouble(0), 1e-5);
        assertEquals(1.0, result.get(5).getSecond().getDouble(1), 1e-5);
    }

    @Test
    public void test_KNN_7() {
        float[] data = new float[]{4, 6};
        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
        assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
        assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
        assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
        assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5);
        assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5);
        assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5);
        assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5);
        assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5);
        assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5);
        assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5);
    }

    @Test
    public void test_KNN_8() {
        float[] data = new float[]{4, 6};
        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f);
        assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
        assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
        assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
        assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5);
        assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5);
        assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5);
        assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5);
        assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5);
        assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5);
        assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5);
        assertEquals(8.0, result.get(5).getSecond().getDouble(0), 1e-5);
        assertEquals(1.0, result.get(5).getSecond().getDouble(1), 1e-5);
    }

    @Test
    public void testNoDuplicates() {
        int N = 100;
        KDTree bigTree = new KDTree(2);

        List<INDArray> points = new ArrayList<>();
        for (int i = 0; i < N; ++i) {
            double[] data = new double[]{i, i};
            points.add(Nd4j.createFromArray(data));
        }

        for (int i = 0; i < N; ++i) {
            bigTree.insert(points.get(i));
        }

        assertEquals(N, bigTree.size());

        INDArray node = Nd4j.empty(DataType.DOUBLE);
        for (int i = 0; i < N; ++i) {
            node = bigTree.delete(node.isEmpty() ? points.get(i) : node);
        }

        assertEquals(0, bigTree.size());
    }

    @Ignore
    @Test
    public void performanceTest() {
        int n = 2;
        int num = 100000;
        // make a KD-tree of dimension {#n}
        long start = System.currentTimeMillis();
        KDTree kdTree = new KDTree(n);
        INDArray inputArray = Nd4j.randn(DataType.DOUBLE, num, n);
        for (int i = 0; i < num; ++i) {
            kdTree.insert(inputArray.getRow(i));
        }

        long end = System.currentTimeMillis();
        Duration duration = new Duration(start, end);
        System.out.println("Elapsed time for tree construction " + duration.getStandardSeconds() + " " + duration.getMillis());

        List<Float> pt = new ArrayList<>(num);
        for (int k = 0; k < n; k++) {
            pt.add((float) (num / 2));
        }
        start = System.currentTimeMillis();
        List<Pair<Float, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0f);
        end = System.currentTimeMillis();
        duration = new Duration(start, end);
        System.out.println("Elapsed time for tree search " + duration.getStandardSeconds() + " " + duration.getMillis());
        for (val pair : list) {
            System.out.println(pair.getFirst() + " " + pair.getSecond());
        }
    }
}
Some files were not shown because too many files have changed in this diff.