Remove more unused modules

parent fa8537f0c7
commit ee06fdd16f
@@ -1,64 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  ~ /* ******************************************************************************
  ~  *
  ~  *
  ~  * This program and the accompanying materials are made available under the
  ~  * terms of the Apache License, Version 2.0 which is available at
  ~  * https://www.apache.org/licenses/LICENSE-2.0.
  ~  *
  ~  * See the NOTICE file distributed with this work for additional
  ~  * information regarding copyright ownership.
  ~  * Unless required by applicable law or agreed to in writing, software
  ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~  * License for the specific language governing permissions and limitations
  ~  * under the License.
  ~  *
  ~  * SPDX-License-Identifier: Apache-2.0
  ~  ******************************************************************************/
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-spark-inference-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>datavec-spark-inference-client</artifactId>

    <name>datavec-spark-inference-client</name>

    <dependencies>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-spark-inference-server_2.11</artifactId>
            <version>1.0.0-SNAPSHOT</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-spark-inference-model</artifactId>
            <version>${project.parent.version}</version>
        </dependency>
        <dependency>
            <groupId>com.mashape.unirest</groupId>
            <artifactId>unirest-java</artifactId>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
        </profile>
    </profiles>
</project>
@@ -1,292 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.client;


import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import com.mashape.unirest.http.exceptions.UnirestException;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.transform.TransformProcess;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.model.*;
import org.datavec.spark.inference.model.service.DataVecTransformService;
import org.nd4j.shade.jackson.core.JsonProcessingException;

import java.io.IOException;

@AllArgsConstructor
@Slf4j
public class DataVecTransformClient implements DataVecTransformService {
    private String url;

    static {
        // Only one time
        Unirest.setObjectMapper(new ObjectMapper() {
            private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
                    new org.nd4j.shade.jackson.databind.ObjectMapper();

            public <T> T readValue(String value, Class<T> valueType) {
                try {
                    return jacksonObjectMapper.readValue(value, valueType);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }

            public String writeValue(Object value) {
                try {
                    return jacksonObjectMapper.writeValueAsString(value);
                } catch (JsonProcessingException e) {
                    throw new RuntimeException(e);
                }
            }
        });
    }

    /**
     * @param transformProcess
     */
    @Override
    public void setCSVTransformProcess(TransformProcess transformProcess) {
        try {
            String s = transformProcess.toJson();
            Unirest.post(url + "/transformprocess").header("accept", "application/json")
                    .header("Content-Type", "application/json").body(s).asJson();

        } catch (UnirestException e) {
            log.error("Error in setCSVTransformProcess()", e);
        }
    }

    @Override
    public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    /**
     * @return
     */
    @Override
    public TransformProcess getCSVTransformProcess() {
        try {
            String s = Unirest.get(url + "/transformprocess").header("accept", "application/json")
                    .header("Content-Type", "application/json").asString().getBody();
            return TransformProcess.fromJson(s);
        } catch (UnirestException e) {
            log.error("Error in getCSVTransformProcess()", e);
        }

        return null;
    }

    @Override
    public ImageTransformProcess getImageTransformProcess() {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    /**
     * @param transform
     * @return
     */
    @Override
    public SingleCSVRecord transformIncremental(SingleCSVRecord transform) {
        try {
            SingleCSVRecord singleCsvRecord = Unirest.post(url + "/transformincremental")
                    .header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .body(transform).asObject(SingleCSVRecord.class).getBody();
            return singleCsvRecord;
        } catch (UnirestException e) {
            log.error("Error in transformIncremental(SingleCSVRecord)", e);
        }
        return null;
    }


    /**
     * @param batchCSVRecord
     * @return
     */
    @Override
    public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
        try {
            SequenceBatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform").header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "TRUE")
                    .body(batchCSVRecord)
                    .asObject(SequenceBatchCSVRecord.class)
                    .getBody();
            return batchCSVRecord1;
        } catch (UnirestException e) {
            log.error("Error in transform(SequenceBatchCSVRecord)", e);
        }

        return null;
    }

    /**
     * @param batchCSVRecord
     * @return
     */
    @Override
    public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) {
        try {
            BatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform").header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "FALSE")
                    .body(batchCSVRecord)
                    .asObject(BatchCSVRecord.class)
                    .getBody();
            return batchCSVRecord1;
        } catch (UnirestException e) {
            log.error("Error in transform(BatchCSVRecord)", e);
        }

        return null;
    }

    /**
     * @param batchCSVRecord
     * @return
     */
    @Override
    public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) {
        try {
            Base64NDArrayBody batchArray1 = Unirest.post(url + "/transformarray").header("accept", "application/json")
                    .header("Content-Type", "application/json").body(batchCSVRecord)
                    .asObject(Base64NDArrayBody.class).getBody();
            return batchArray1;
        } catch (UnirestException e) {
            log.error("Error in transformArray(BatchCSVRecord)", e);
        }

        return null;
    }

    /**
     * @param singleCsvRecord
     * @return
     */
    @Override
    public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) {
        try {
            Base64NDArrayBody array = Unirest.post(url + "/transformincrementalarray")
                    .header("accept", "application/json").header("Content-Type", "application/json")
                    .body(singleCsvRecord).asObject(Base64NDArrayBody.class).getBody();
            return array;
        } catch (UnirestException e) {
            log.error("Error in transformArrayIncremental(SingleCSVRecord)", e);
        }

        return null;
    }

    @Override
    public Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    @Override
    public Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    /**
     * @param singleCsvRecord
     * @return
     */
    @Override
    public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) {
        try {
            Base64NDArrayBody array = Unirest.post(url + "/transformincrementalarray")
                    .header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "true")
                    .body(singleCsvRecord).asObject(Base64NDArrayBody.class).getBody();
            return array;
        } catch (UnirestException e) {
            log.error("Error in transformSequenceArrayIncremental", e);
        }

        return null;
    }

    /**
     * @param batchCSVRecord
     * @return
     */
    @Override
    public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) {
        try {
            Base64NDArrayBody batchArray1 = Unirest.post(url + "/transformarray").header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "true")
                    .body(batchCSVRecord)
                    .asObject(Base64NDArrayBody.class).getBody();
            return batchArray1;
        } catch (UnirestException e) {
            log.error("Error in transformSequenceArray", e);
        }

        return null;
    }

    /**
     * @param batchCSVRecord
     * @return
     */
    @Override
    public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) {
        try {
            SequenceBatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform")
                    .header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "true")
                    .body(batchCSVRecord)
                    .asObject(SequenceBatchCSVRecord.class).getBody();
            return batchCSVRecord1;
        } catch (UnirestException e) {
            log.error("Error in transformSequence", e);
        }

        return null;
    }

    /**
     * @param transform
     * @return
     */
    @Override
    public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) {
        try {
            SequenceBatchCSVRecord singleCsvRecord = Unirest.post(url + "/transformincremental")
                    .header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "true")
                    .body(transform).asObject(SequenceBatchCSVRecord.class).getBody();
            return singleCsvRecord;
        } catch (UnirestException e) {
            log.error("Error in transformSequenceIncremental", e);
        }
        return null;
    }
}
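For reference, a minimal usage sketch of the client deleted above. Assumptions are marked in comments: the URL and port are illustrative (9600 happens to be the port in the removed Play config), a transform server such as CSVSparkTransformServer must already be listening there, and the Schema/TransformProcess construction mirrors the removed test below.

    import org.datavec.api.transform.TransformProcess;
    import org.datavec.api.transform.schema.Schema;
    import org.datavec.spark.inference.client.DataVecTransformClient;
    import org.datavec.spark.inference.model.model.Base64NDArrayBody;
    import org.datavec.spark.inference.model.model.SingleCSVRecord;

    public class TransformClientExample {
        public static void main(String[] args) {
            // Two double columns; the transform process converts both in place
            Schema schema = new Schema.Builder()
                    .addColumnDouble("1.0").addColumnDouble("2.0").build();
            TransformProcess transformProcess = new TransformProcess.Builder(schema)
                    .convertToDouble("1.0").convertToDouble("2.0").build();

            // Illustrative URL: a transform server must already be running here
            DataVecTransformClient client = new DataVecTransformClient("http://localhost:9600");
            client.setCSVTransformProcess(transformProcess);

            // Transform one CSV row, then fetch the same row as a base64-encoded NDArray
            SingleCSVRecord transformed = client.transformIncremental(new SingleCSVRecord("0", "0"));
            Base64NDArrayBody array = client.transformArrayIncremental(new SingleCSVRecord("0", "0"));
        }
    }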
@@ -1,45 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.datavec.transform.client;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;

@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {

    @Override
    protected Set<Class<?>> getExclusions() {
        //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
        return new HashSet<>();
    }

    @Override
    protected String getPackageName() {
        return "org.datavec.transform.client";
    }

    @Override
    protected Class<?> getBaseClass() {
        return BaseND4JTest.class;
    }
}
@@ -1,139 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.transform.client;

import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.spark.inference.server.CSVSparkTransformServer;
import org.datavec.spark.inference.client.DataVecTransformClient;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;

import java.io.File;
import java.io.IOException;
import java.net.ServerSocket;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;

import static org.junit.Assert.assertEquals;
import static org.junit.Assume.assumeNotNull;

public class DataVecTransformClientTest {
    private static CSVSparkTransformServer server;
    private static int port = getAvailablePort();
    private static DataVecTransformClient client;
    private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
    private static TransformProcess transformProcess =
            new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build();
    private static File fileSave = new File(UUID.randomUUID().toString() + ".json");

    @BeforeClass
    public static void beforeClass() throws Exception {
        FileUtils.write(fileSave, transformProcess.toJson());
        fileSave.deleteOnExit();
        server = new CSVSparkTransformServer();
        server.runMain(new String[] {"-dp", String.valueOf(port)});

        client = new DataVecTransformClient("http://localhost:" + port);
        client.setCSVTransformProcess(transformProcess);
    }

    @AfterClass
    public static void afterClass() throws Exception {
        server.stop();
    }


    @Test
    public void testSequenceClient() {
        SequenceBatchCSVRecord sequenceBatchCSVRecord = new SequenceBatchCSVRecord();
        SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"});

        BatchCSVRecord batchCSVRecord = new BatchCSVRecord(Arrays.asList(singleCsvRecord, singleCsvRecord));
        List<BatchCSVRecord> batchCSVRecordList = new ArrayList<>();
        for (int i = 0; i < 5; i++) {
            batchCSVRecordList.add(batchCSVRecord);
        }

        sequenceBatchCSVRecord.add(batchCSVRecordList);

        SequenceBatchCSVRecord sequenceBatchCSVRecord1 = client.transformSequence(sequenceBatchCSVRecord);
        assumeNotNull(sequenceBatchCSVRecord1);

        Base64NDArrayBody array = client.transformSequenceArray(sequenceBatchCSVRecord);
        assumeNotNull(array);

        Base64NDArrayBody incrementalBody = client.transformSequenceArrayIncremental(batchCSVRecord);
        assumeNotNull(incrementalBody);

        Base64NDArrayBody incrementalSequenceBody = client.transformSequenceArrayIncremental(batchCSVRecord);
        assumeNotNull(incrementalSequenceBody);
    }

    @Test
    public void testRecord() throws Exception {
        SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"});
        SingleCSVRecord transformed = client.transformIncremental(singleCsvRecord);
        assertEquals(singleCsvRecord.getValues().size(), transformed.getValues().size());
        Base64NDArrayBody body = client.transformArrayIncremental(singleCsvRecord);
        INDArray arr = Nd4jBase64.fromBase64(body.getNdarray());
        assumeNotNull(arr);
    }

    @Test
    public void testBatchRecord() throws Exception {
        SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"});

        BatchCSVRecord batchCSVRecord = new BatchCSVRecord(Arrays.asList(singleCsvRecord, singleCsvRecord));
        BatchCSVRecord batchCSVRecord1 = client.transform(batchCSVRecord);
        assertEquals(batchCSVRecord.getRecords().size(), batchCSVRecord1.getRecords().size());

        Base64NDArrayBody body = client.transformArray(batchCSVRecord);
        INDArray arr = Nd4jBase64.fromBase64(body.getNdarray());
        assumeNotNull(arr);
    }



    public static int getAvailablePort() {
        try {
            ServerSocket socket = new ServerSocket(0);
            try {
                return socket.getLocalPort();
            } finally {
                socket.close();
            }
        } catch (IOException e) {
            throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e);
        }
    }

}
@@ -1,6 +0,0 @@
play.modules.enabled += com.lightbend.lagom.discovery.zookeeper.ZooKeeperServiceLocatorModule
play.modules.enabled += io.skymind.skil.service.PredictionModule
play.crypto.secret = as8dufasdfuasdfjkasdkfalksjfk
play.server.pidfile.path=/tmp/RUNNING_PID

play.server.http.port = 9600
@@ -1,63 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  ~ /* ******************************************************************************
  ~  *
  ~  *
  ~  * This program and the accompanying materials are made available under the
  ~  * terms of the Apache License, Version 2.0 which is available at
  ~  * https://www.apache.org/licenses/LICENSE-2.0.
  ~  *
  ~  * See the NOTICE file distributed with this work for additional
  ~  * information regarding copyright ownership.
  ~  * Unless required by applicable law or agreed to in writing, software
  ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~  * License for the specific language governing permissions and limitations
  ~  * under the License.
  ~  *
  ~  * SPDX-License-Identifier: Apache-2.0
  ~  ******************************************************************************/
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-spark-inference-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>datavec-spark-inference-model</artifactId>

    <name>datavec-spark-inference-model</name>

    <dependencies>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-api</artifactId>
            <version>${datavec.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-data-image</artifactId>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-local</artifactId>
            <version>${project.version}</version>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
        </profile>
    </profiles>
</project>
@@ -1,286 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.model;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.Writable;
import org.datavec.arrow.ArrowConverter;
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;

import static org.datavec.arrow.ArrowConverter.*;
import static org.datavec.local.transforms.LocalTransformExecutor.execute;
import static org.datavec.local.transforms.LocalTransformExecutor.executeToSequence;

@AllArgsConstructor
@Slf4j
public class CSVSparkTransform {
    @Getter
    private TransformProcess transformProcess;
    private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);

    /**
     * Convert a raw record via
     * the {@link TransformProcess}
     * to a base 64ed ndarray
     * @param batch the record to convert
     * @return the base 64ed ndarray
     * @throws IOException
     */
    public Base64NDArrayBody toArray(BatchCSVRecord batch) throws IOException {
        List<List<Writable>> converted = execute(toArrowWritables(toArrowColumnsString(
                bufferAllocator, transformProcess.getInitialSchema(),
                batch.getRecordsAsString()),
                transformProcess.getInitialSchema()), transformProcess);

        ArrowWritableRecordBatch arrowRecordBatch = (ArrowWritableRecordBatch) converted;
        INDArray convert = ArrowConverter.toArray(arrowRecordBatch);
        return new Base64NDArrayBody(Nd4jBase64.base64String(convert));
    }

    /**
     * Convert a raw record via
     * the {@link TransformProcess}
     * to a base 64ed ndarray
     * @param record the record to convert
     * @return the base 64ed ndarray
     * @throws IOException
     */
    public Base64NDArrayBody toArray(SingleCSVRecord record) throws IOException {
        List<Writable> record2 = toArrowWritablesSingle(
                toArrowColumnsStringSingle(bufferAllocator,
                        transformProcess.getInitialSchema(), record.getValues()),
                transformProcess.getInitialSchema());
        List<Writable> finalRecord = execute(Arrays.asList(record2), transformProcess).get(0);
        INDArray convert = RecordConverter.toArray(DataType.DOUBLE, finalRecord);
        return new Base64NDArrayBody(Nd4jBase64.base64String(convert));
    }

    /**
     * Runs the transform process
     * @param batch the record to transform
     * @return the transformed record
     */
    public BatchCSVRecord transform(BatchCSVRecord batch) {
        BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
        List<List<Writable>> converted = execute(toArrowWritables(toArrowColumnsString(
                bufferAllocator, transformProcess.getInitialSchema(),
                batch.getRecordsAsString()),
                transformProcess.getInitialSchema()), transformProcess);
        int numCols = converted.get(0).size();
        for (int row = 0; row < converted.size(); row++) {
            String[] values = new String[numCols];
            for (int i = 0; i < values.length; i++)
                values[i] = converted.get(row).get(i).toString();
            batchCSVRecord.add(new SingleCSVRecord(values));
        }

        return batchCSVRecord;
    }

    /**
     * Runs the transform process
     * @param record the record to transform
     * @return the transformed record
     */
    public SingleCSVRecord transform(SingleCSVRecord record) {
        List<Writable> record2 = toArrowWritablesSingle(
                toArrowColumnsStringSingle(bufferAllocator,
                        transformProcess.getInitialSchema(), record.getValues()),
                transformProcess.getInitialSchema());
        List<Writable> finalRecord = execute(Arrays.asList(record2), transformProcess).get(0);
        String[] values = new String[finalRecord.size()];
        for (int i = 0; i < values.length; i++)
            values[i] = finalRecord.get(i).toString();
        return new SingleCSVRecord(values);
    }

    /**
     *
     * @param transform
     * @return
     */
    public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) {
        /**
         * Sequence schema?
         */
        List<List<List<Writable>>> converted = executeToSequence(
                toArrowWritables(toArrowColumnsStringTimeSeries(
                        bufferAllocator, transformProcess.getInitialSchema(),
                        Arrays.asList(transform.getRecordsAsString())),
                        transformProcess.getInitialSchema()), transformProcess);

        SequenceBatchCSVRecord batchCSVRecord = new SequenceBatchCSVRecord();
        for (int i = 0; i < converted.size(); i++) {
            BatchCSVRecord batchCSVRecord1 = BatchCSVRecord.fromWritables(converted.get(i));
            batchCSVRecord.add(Arrays.asList(batchCSVRecord1));
        }

        return batchCSVRecord;
    }

    /**
     *
     * @param batchCSVRecordSequence
     * @return
     */
    public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecordSequence) {
        List<List<List<String>>> recordsAsString = batchCSVRecordSequence.getRecordsAsString();
        boolean allSameLength = true;
        Integer length = null;
        for (List<List<String>> record : recordsAsString) {
            if (length == null) {
                length = record.size();
            } else if (record.size() != length) {
                allSameLength = false;
            }
        }

        if (allSameLength) {
            List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator,
                    transformProcess.getInitialSchema(), recordsAsString);
            ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch =
                    new ArrowWritableRecordTimeSeriesBatch(fieldVectors,
                            transformProcess.getInitialSchema(),
                            recordsAsString.get(0).get(0).size());
            val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch, transformProcess);
            return SequenceBatchCSVRecord.fromWritables(transformed);
        } else {
            val transformed = LocalTransformExecutor.executeSequenceToSequence(
                    LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecordSequence.getRecordsAsString(),
                            transformProcess.getInitialSchema()), transformProcess);
            return SequenceBatchCSVRecord.fromWritables(transformed);
        }
    }

    /**
     * TODO: optimize
     * @param batchCSVRecordSequence
     * @return
     */
    public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecordSequence) {
        List<List<List<String>>> strings = batchCSVRecordSequence.getRecordsAsString();
        boolean allSameLength = true;
        Integer length = null;
        for (List<List<String>> record : strings) {
            if (length == null) {
                length = record.size();
            } else if (record.size() != length) {
                allSameLength = false;
            }
        }

        if (allSameLength) {
            List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator,
                    transformProcess.getInitialSchema(), strings);
            ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch =
                    new ArrowWritableRecordTimeSeriesBatch(fieldVectors, transformProcess.getInitialSchema(),
                            strings.get(0).get(0).size());
            val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch, transformProcess);
            INDArray arr = RecordConverter.toTensor(transformed)
                    .reshape(strings.size(), strings.get(0).get(0).size(), strings.get(0).size());
            try {
                return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
            } catch (IOException e) {
                throw new IllegalStateException(e);
            }
        } else {
            val transformed = LocalTransformExecutor.executeSequenceToSequence(
                    LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecordSequence.getRecordsAsString(),
                            transformProcess.getInitialSchema()), transformProcess);
            INDArray arr = RecordConverter.toTensor(transformed)
                    .reshape(strings.size(), strings.get(0).get(0).size(), strings.get(0).size());
            try {
                return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
            } catch (IOException e) {
                throw new IllegalStateException(e);
            }
        }
    }

    /**
     *
     * @param singleCsvRecord
     * @return
     */
    public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) {
        List<List<List<Writable>>> converted = executeToSequence(toArrowWritables(toArrowColumnsString(
                bufferAllocator, transformProcess.getInitialSchema(),
                singleCsvRecord.getRecordsAsString()),
                transformProcess.getInitialSchema()), transformProcess);
        ArrowWritableRecordTimeSeriesBatch arrowWritableRecordBatch = (ArrowWritableRecordTimeSeriesBatch) converted;
        INDArray arr = RecordConverter.toTensor(arrowWritableRecordBatch);
        try {
            return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
        } catch (IOException e) {
            log.error("Error in transformSequenceArrayIncremental", e);
        }

        return null;
    }

    public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
        List<List<List<String>>> strings = batchCSVRecord.getRecordsAsString();
        boolean allSameLength = true;
        Integer length = null;
        for (List<List<String>> record : strings) {
            if (length == null) {
                length = record.size();
            } else if (record.size() != length) {
                allSameLength = false;
            }
        }

        if (allSameLength) {
            List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator,
                    transformProcess.getInitialSchema(), strings);
            ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch =
                    new ArrowWritableRecordTimeSeriesBatch(fieldVectors, transformProcess.getInitialSchema(),
                            strings.get(0).get(0).size());
            val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch, transformProcess);
            return SequenceBatchCSVRecord.fromWritables(transformed);
        } else {
            val transformed = LocalTransformExecutor.executeSequenceToSequence(
                    LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecord.getRecordsAsString(),
                            transformProcess.getInitialSchema()), transformProcess);
            return SequenceBatchCSVRecord.fromWritables(transformed);
        }
    }
}
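A short sketch of how the transform logic above was driven without the HTTP layer. Assumptions: the Schema/TransformProcess construction mirrors the client test earlier in this commit, and no Spark context or server is needed because CSVSparkTransform delegates to LocalTransformExecutor.

    import org.datavec.api.transform.TransformProcess;
    import org.datavec.api.transform.schema.Schema;
    import org.datavec.spark.inference.model.CSVSparkTransform;
    import org.datavec.spark.inference.model.model.Base64NDArrayBody;
    import org.datavec.spark.inference.model.model.SingleCSVRecord;

    public class CSVSparkTransformExample {
        public static void main(String[] args) throws Exception {
            Schema schema = new Schema.Builder()
                    .addColumnDouble("1.0").addColumnDouble("2.0").build();
            TransformProcess transformProcess = new TransformProcess.Builder(schema)
                    .convertToDouble("1.0").convertToDouble("2.0").build();

            // @AllArgsConstructor: the transform is built directly from the process
            CSVSparkTransform transform = new CSVSparkTransform(transformProcess);

            // Transform one CSV row, then encode the same row as a base64 NDArray
            SingleCSVRecord out = transform.transform(new SingleCSVRecord("1", "2"));
            Base64NDArrayBody body = transform.toArray(new SingleCSVRecord("1", "2"));
        }
    }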
@@ -1,64 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.model;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.datavec.image.data.ImageWritable;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchImageRecord;
import org.datavec.spark.inference.model.model.SingleImageRecord;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.serde.base64.Nd4jBase64;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

@AllArgsConstructor
public class ImageSparkTransform {
    @Getter
    private ImageTransformProcess imageTransformProcess;

    public Base64NDArrayBody toArray(SingleImageRecord record) throws IOException {
        ImageWritable record2 = imageTransformProcess.transformFileUriToInput(record.getUri());
        INDArray finalRecord = imageTransformProcess.executeArray(record2);

        return new Base64NDArrayBody(Nd4jBase64.base64String(finalRecord));
    }

    public Base64NDArrayBody toArray(BatchImageRecord batch) throws IOException {
        List<INDArray> records = new ArrayList<>();

        for (SingleImageRecord imgRecord : batch.getRecords()) {
            ImageWritable record2 = imageTransformProcess.transformFileUriToInput(imgRecord.getUri());
            INDArray finalRecord = imageTransformProcess.executeArray(record2);
            records.add(finalRecord);
        }

        INDArray array = Nd4j.concat(0, records.toArray(new INDArray[records.size()]));

        return new Base64NDArrayBody(Nd4jBase64.base64String(array));
    }

}
@@ -1,32 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.model.model;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class Base64NDArrayBody {
    private String ndarray;
}
@@ -1,104 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.model.model;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.dataset.DataSet;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

@Data
@AllArgsConstructor
@Builder
@NoArgsConstructor
public class BatchCSVRecord implements Serializable {
    private List<SingleCSVRecord> records;


    /**
     * Get the records as a list of strings
     * (basically the underlying values for
     * {@link SingleCSVRecord})
     * @return
     */
    public List<List<String>> getRecordsAsString() {
        if (records == null)
            records = new ArrayList<>();
        List<List<String>> ret = new ArrayList<>();
        for (SingleCSVRecord csvRecord : records) {
            ret.add(csvRecord.getValues());
        }
        return ret;
    }


    /**
     * Create a batch csv record
     * from a list of writables.
     * @param batch
     * @return
     */
    public static BatchCSVRecord fromWritables(List<List<Writable>> batch) {
        List<SingleCSVRecord> records = new ArrayList<>(batch.size());
        for (List<Writable> list : batch) {
            List<String> add = new ArrayList<>(list.size());
            for (Writable writable : list) {
                add.add(writable.toString());
            }
            records.add(new SingleCSVRecord(add));
        }

        return BatchCSVRecord.builder().records(records).build();
    }


    /**
     * Add a record
     * @param record
     */
    public void add(SingleCSVRecord record) {
        if (records == null)
            records = new ArrayList<>();
        records.add(record);
    }


    /**
     * Return a batch record based on a dataset
     * @param dataSet the dataset to get the batch record for
     * @return the batch record
     */
    public static BatchCSVRecord fromDataSet(DataSet dataSet) {
        BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
        for (int i = 0; i < dataSet.numExamples(); i++) {
            batchCSVRecord.add(SingleCSVRecord.fromRow(dataSet.get(i)));
        }

        return batchCSVRecord;
    }

}
@@ -1,50 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.model.model;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.net.URI;
import java.util.ArrayList;
import java.util.List;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class BatchImageRecord {
    private List<SingleImageRecord> records;

    /**
     * Add a record
     * @param record
     */
    public void add(SingleImageRecord record) {
        if (records == null)
            records = new ArrayList<>();
        records.add(record);
    }

    public void add(URI uri) {
        this.add(new SingleImageRecord(uri));
    }
}
@@ -1,106 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.model.model;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.MultiDataSet;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

@Data
@AllArgsConstructor
@Builder
@NoArgsConstructor
public class SequenceBatchCSVRecord implements Serializable {
    private List<List<BatchCSVRecord>> records;

    /**
     * Add a record
     * @param record
     */
    public void add(List<BatchCSVRecord> record) {
        if (records == null)
            records = new ArrayList<>();
        records.add(record);
    }

    /**
     * Get the records as a list of strings directly
     * (this basically "unpacks" the objects)
     * @return
     */
    public List<List<List<String>>> getRecordsAsString() {
        if (records == null)
            return Collections.emptyList();
        List<List<List<String>>> ret = new ArrayList<>(records.size());
        for (List<BatchCSVRecord> record : records) {
            List<List<String>> add = new ArrayList<>();
            for (BatchCSVRecord batchCSVRecord : record) {
                for (SingleCSVRecord singleCSVRecord : batchCSVRecord.getRecords()) {
                    add.add(singleCSVRecord.getValues());
                }
            }

            ret.add(add);
        }

        return ret;
    }

    /**
     * Convert a writables time series to a sequence batch
     * @param input
     * @return
     */
    public static SequenceBatchCSVRecord fromWritables(List<List<List<Writable>>> input) {
        SequenceBatchCSVRecord ret = new SequenceBatchCSVRecord();
        for (int i = 0; i < input.size(); i++) {
            ret.add(Arrays.asList(BatchCSVRecord.fromWritables(input.get(i))));
        }

        return ret;
    }


    /**
     * Return a batch record based on a dataset
     * @param dataSet the dataset to get the batch record for
     * @return the batch record
     */
    public static SequenceBatchCSVRecord fromDataSet(MultiDataSet dataSet) {
        SequenceBatchCSVRecord batchCSVRecord = new SequenceBatchCSVRecord();
        for (int i = 0; i < dataSet.numFeatureArrays(); i++) {
            batchCSVRecord.add(Arrays.asList(BatchCSVRecord.fromDataSet(
                    new DataSet(dataSet.getFeatures(i), dataSet.getLabels(i)))));
        }

        return batchCSVRecord;
    }

}
@@ -1,95 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.model.model;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.dataset.DataSet;

import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class SingleCSVRecord implements Serializable {
    private List<String> values;

    /**
     * Create from an array of values (uses a list internally)
     * @param values
     */
    public SingleCSVRecord(String... values) {
        this.values = Arrays.asList(values);
    }

    /**
     * Instantiate a csv record from a vector.
     * Given an input dataset with a one-hot label matrix, the label index
     * will be appended to the end of the record; for regression,
     * all values in the labels are appended instead.
     * @param row the input vectors
     * @return the record from this {@link DataSet}
     */
    public static SingleCSVRecord fromRow(DataSet row) {
        if (!row.getFeatures().isVector() && !row.getFeatures().isScalar())
            throw new IllegalArgumentException("Passed in dataset must represent a scalar or vector");
        if (!row.getLabels().isVector() && !row.getLabels().isScalar())
            throw new IllegalArgumentException("Passed in dataset labels must be a scalar or vector");
        //classification
        SingleCSVRecord record;
        int idx = 0;
        if (row.getLabels().sumNumber().doubleValue() == 1.0) {
            String[] values = new String[row.getFeatures().columns() + 1];
            for (int i = 0; i < row.getFeatures().length(); i++) {
                values[idx++] = String.valueOf(row.getFeatures().getDouble(i));
            }
            int maxIdx = 0;
            for (int i = 0; i < row.getLabels().length(); i++) {
                if (row.getLabels().getDouble(maxIdx) < row.getLabels().getDouble(i)) {
                    maxIdx = i;
                }
            }

            values[idx++] = String.valueOf(maxIdx);
            record = new SingleCSVRecord(values);
        }
        //regression (any number of values)
        else {
            String[] values = new String[row.getFeatures().columns() + row.getLabels().columns()];
            for (int i = 0; i < row.getFeatures().length(); i++) {
                values[idx++] = String.valueOf(row.getFeatures().getDouble(i));
            }
            for (int i = 0; i < row.getLabels().length(); i++) {
                values[idx++] = String.valueOf(row.getLabels().getDouble(i));
            }


            record = new SingleCSVRecord(values);

        }
        return record;
    }

}
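To make the classification/regression split in fromRow above concrete, a hedged worked example (assumption: Nd4j.create(double[]) yields a row vector, as in the ND4J versions this module targeted):

    import org.datavec.spark.inference.model.model.SingleCSVRecord;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.dataset.DataSet;
    import org.nd4j.linalg.factory.Nd4j;

    public class SingleCSVRecordExample {
        public static void main(String[] args) {
            INDArray features = Nd4j.create(new double[] {0.25, 0.75});
            INDArray labels = Nd4j.create(new double[] {0.0, 1.0}); // one-hot: sums to 1.0

            // Labels sum to 1.0, so the classification branch runs and the
            // argmax label index (1) is appended after the feature values:
            SingleCSVRecord record = SingleCSVRecord.fromRow(new DataSet(features, labels));
            // record.getValues() -> ["0.25", "0.75", "1"]
        }
    }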
@@ -1,34 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.model.model;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.net.URI;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class SingleImageRecord {
    private URI uri;
}
@ -1,131 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.model.service;

import org.datavec.api.transform.TransformProcess;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.model.*;

import java.io.IOException;

public interface DataVecTransformService {

    /** Request header whose value ("true"/"false") marks a payload as a sequence. */
    String SEQUENCE_OR_NOT_HEADER = "Sequence";

    /**
     * Set the transform process used for CSV records.
     *
     * @param transformProcess the transform process to apply
     */
    void setCSVTransformProcess(TransformProcess transformProcess);

    /**
     * Set the transform process used for images.
     *
     * @param imageTransformProcess the image transform process to apply
     */
    void setImageTransformProcess(ImageTransformProcess imageTransformProcess);

    /**
     * @return the current CSV transform process
     */
    TransformProcess getCSVTransformProcess();

    /**
     * @return the current image transform process
     */
    ImageTransformProcess getImageTransformProcess();

    /**
     * Transform a single CSV record.
     *
     * @param singleCsvRecord the record to transform
     * @return the transformed record
     */
    SingleCSVRecord transformIncremental(SingleCSVRecord singleCsvRecord);

    /**
     * Transform a batch of CSV sequences.
     *
     * @param batchCSVRecord the sequence batch to transform
     * @return the transformed sequence batch
     */
    SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord);

    /**
     * Transform a batch of CSV records.
     *
     * @param batchCSVRecord the batch to transform
     * @return the transformed batch
     */
    BatchCSVRecord transform(BatchCSVRecord batchCSVRecord);

    /**
     * Transform a batch of CSV records into a base64-encoded NDArray.
     *
     * @param batchCSVRecord the batch to transform
     * @return the resulting array, base64 encoded
     */
    Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord);

    /**
     * Transform a single CSV record into a base64-encoded NDArray.
     *
     * @param singleCsvRecord the record to transform
     * @return the resulting array, base64 encoded
     */
    Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord);

    /**
     * Transform a single image into a base64-encoded NDArray.
     *
     * @param singleImageRecord the image to transform
     * @return the resulting array, base64 encoded
     * @throws IOException if the image cannot be read
     */
    Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException;

    /**
     * Transform a batch of images into a base64-encoded NDArray.
     *
     * @param batchImageRecord the image batch to transform
     * @return the resulting array, base64 encoded
     * @throws IOException if an image cannot be read
     */
    Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException;

    /**
     * Transform a batch of CSV records into a base64-encoded sequence NDArray.
     *
     * @param singleCsvRecord the batch to transform
     * @return the resulting array, base64 encoded
     */
    Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord);

    /**
     * Transform a batch of CSV sequences into a base64-encoded NDArray.
     *
     * @param batchCSVRecord the sequence batch to transform
     * @return the resulting array, base64 encoded
     */
    Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord);

    /**
     * Transform a batch of CSV sequences.
     *
     * @param batchCSVRecord the sequence batch to transform
     * @return the transformed sequence batch
     */
    SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord);

    /**
     * Transform a batch of CSV records into a sequence batch.
     *
     * @param transform the batch to transform
     * @return the resulting sequence batch
     */
    SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform);
}
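The servers further down in this commit expose this contract over HTTP (default port 9000), using the "Sequence" header above to pick the code path. A minimal client sketch using the unirest-java dependency declared in the server pom below; the URL, header value, and record payload here are illustrative assumptions, not part of the removed code:

import com.mashape.unirest.http.HttpResponse;
import com.mashape.unirest.http.Unirest;

public class TransformClientSketch {
    public static void main(String[] args) throws Exception {
        // POST a BatchCSVRecord-shaped JSON payload to the /transform endpoint.
        // "Sequence: false" selects the plain batch path (see isSequence in
        // SparkTransformServer below); "true" would route to the sequence path.
        HttpResponse<String> response = Unirest.post("http://localhost:9000/transform")
                .header("Content-Type", "application/json")
                .header("Sequence", "false")
                .body("{\"records\":[{\"values\":[\"1.0\",\"2.0\"]}]}")
                .asString();
        System.out.println(response.getStatus() + ": " + response.getBody());
    }
}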
@ -1,46 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.transform;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;

import java.util.*;

@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {

    @Override
    protected Set<Class<?>> getExclusions() {
        //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
        return new HashSet<>();
    }

    @Override
    protected String getPackageName() {
        return "org.datavec.spark.transform";
    }

    @Override
    protected Class<?> getBaseClass() {
        return BaseND4JTest.class;
    }
}
@ -1,40 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.transform;

import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.junit.Test;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

import static org.junit.Assert.assertEquals;

public class BatchCSVRecordTest {

    @Test
    public void testBatchRecordCreationFromDataSet() {
        DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{1, 1}, {1, 1}}));

        BatchCSVRecord batchCSVRecord = BatchCSVRecord.fromDataSet(dataSet);
        assertEquals(2, batchCSVRecord.getRecords().size());
    }

}
@ -1,212 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.transform;

import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.integer.BaseIntegerTransform;
import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;
import org.datavec.spark.inference.model.CSVSparkTransform;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;

import java.util.*;

import static org.junit.Assert.*;

public class CSVSparkTransformTest {

    @Test
    public void testTransformer() throws Exception {
        List<Writable> input = new ArrayList<>();
        input.add(new DoubleWritable(1.0));
        input.add(new DoubleWritable(2.0));

        Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
        List<Writable> output = new ArrayList<>();
        output.add(new Text("1.0"));
        output.add(new Text("2.0"));

        TransformProcess transformProcess =
                new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build();
        CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess);
        String[] values = new String[] {"1.0", "2.0"};
        SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values));
        Base64NDArrayBody body = csvSparkTransform.toArray(new SingleCSVRecord(values));
        INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
        assertTrue(fromBase64.isVector());
        // System.out.println("Base 64ed array " + fromBase64);
    }

    @Test
    public void testTransformerBatch() throws Exception {
        List<Writable> input = new ArrayList<>();
        input.add(new DoubleWritable(1.0));
        input.add(new DoubleWritable(2.0));

        Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
        List<Writable> output = new ArrayList<>();
        output.add(new Text("1.0"));
        output.add(new Text("2.0"));

        TransformProcess transformProcess =
                new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build();
        CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess);
        String[] values = new String[] {"1.0", "2.0"};
        SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values));
        BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
        for (int i = 0; i < 3; i++)
            batchCSVRecord.add(record);
        //data type is string, unable to convert
        BatchCSVRecord batchCSVRecord1 = csvSparkTransform.transform(batchCSVRecord);
        /* Base64NDArrayBody body = csvSparkTransform.toArray(batchCSVRecord1);
        INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
        assertTrue(fromBase64.isMatrix());
        System.out.println("Base 64ed array " + fromBase64); */
    }

    @Test
    public void testSingleBatchSequence() throws Exception {
        List<Writable> input = new ArrayList<>();
        input.add(new DoubleWritable(1.0));
        input.add(new DoubleWritable(2.0));

        Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
        List<Writable> output = new ArrayList<>();
        output.add(new Text("1.0"));
        output.add(new Text("2.0"));

        TransformProcess transformProcess =
                new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build();
        CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess);
        String[] values = new String[] {"1.0", "2.0"};
        SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values));
        BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
        for (int i = 0; i < 3; i++)
            batchCSVRecord.add(record);
        BatchCSVRecord batchCSVRecord1 = csvSparkTransform.transform(batchCSVRecord);
        SequenceBatchCSVRecord sequenceBatchCSVRecord = new SequenceBatchCSVRecord();
        sequenceBatchCSVRecord.add(Arrays.asList(batchCSVRecord));
        Base64NDArrayBody sequenceArray = csvSparkTransform.transformSequenceArray(sequenceBatchCSVRecord);
        INDArray outputBody = Nd4jBase64.fromBase64(sequenceArray.getNdarray());

        //ensure accumulation
        sequenceBatchCSVRecord.add(Arrays.asList(batchCSVRecord));
        sequenceArray = csvSparkTransform.transformSequenceArray(sequenceBatchCSVRecord);
        assertArrayEquals(new long[] {2, 2, 3}, Nd4jBase64.fromBase64(sequenceArray.getNdarray()).shape());

        SequenceBatchCSVRecord transformed = csvSparkTransform.transformSequence(sequenceBatchCSVRecord);
        assertNotNull(transformed.getRecords());
        // System.out.println(transformed);
    }

    @Test
    public void testSpecificSequence() throws Exception {
        final Schema schema = new Schema.Builder()
                .addColumnsString("action")
                .build();

        final TransformProcess transformProcess = new TransformProcess.Builder(schema)
                .removeAllColumnsExceptFor("action")
                .transform(new ConverToLowercase("action"))
                .convertToSequence()
                .transform(new TextToCharacterIndexTransform("action", "action_sequence",
                        defaultCharIndex(), false))
                .integerToOneHot("action_sequence", 0, 29)
                .build();

        final String[] data1 = new String[] {"test1"};
        final String[] data2 = new String[] {"test2"};
        final BatchCSVRecord batchCsvRecord = new BatchCSVRecord(
                Arrays.asList(
                        new SingleCSVRecord(data1),
                        new SingleCSVRecord(data2)));

        final CSVSparkTransform transform = new CSVSparkTransform(transformProcess);
        // System.out.println(transform.transformSequenceIncremental(batchCsvRecord));
        transform.transformSequenceIncremental(batchCsvRecord);
        assertEquals(3, Nd4jBase64.fromBase64(transform.transformSequenceArrayIncremental(batchCsvRecord).getNdarray()).rank());
    }

    private static Map<Character, Integer> defaultCharIndex() {
        Map<Character, Integer> ret = new TreeMap<>();
        //'a'..'z' map to indices 0..25
        for (char c = 'a'; c <= 'z'; c++)
            ret.put(c, c - 'a');
        ret.put('/', 26);
        ret.put(' ', 27);
        ret.put('(', 28);
        ret.put(')', 29);
        return ret;
    }

    public static class ConverToLowercase extends BaseIntegerTransform {
        public ConverToLowercase(String column) {
            super(column);
        }

        public Text map(Writable writable) {
            return new Text(writable.toString().toLowerCase());
        }

        public Object map(Object input) {
            return new Text(input.toString().toLowerCase());
        }
    }
}
@ -1,86 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.transform;

import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.ImageSparkTransform;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchImageRecord;
import org.datavec.spark.inference.model.model.SingleImageRecord;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.serde.base64.Nd4jBase64;

import java.io.File;

import static org.junit.Assert.assertEquals;

public class ImageSparkTransformTest {

    @Rule
    public TemporaryFolder testDir = new TemporaryFolder();

    @Test
    public void testSingleImageSparkTransform() throws Exception {
        int seed = 12345;

        File f1 = new ClassPathResource("datavec-spark-inference/testimages/class1/A.jpg").getFile();

        SingleImageRecord imgRecord = new SingleImageRecord(f1.toURI());

        ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(seed)
                .scaleImageTransform(10).cropImageTransform(5).build();

        ImageSparkTransform imgSparkTransform = new ImageSparkTransform(imgTransformProcess);
        Base64NDArrayBody body = imgSparkTransform.toArray(imgRecord);

        INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
        // System.out.println("Base 64ed array " + fromBase64);
        assertEquals(1, fromBase64.size(0));
    }

    @Test
    public void testBatchImageSparkTransform() throws Exception {
        int seed = 12345;

        File f0 = new ClassPathResource("datavec-spark-inference/testimages/class1/A.jpg").getFile();
        File f1 = new ClassPathResource("datavec-spark-inference/testimages/class1/B.png").getFile();
        File f2 = new ClassPathResource("datavec-spark-inference/testimages/class1/C.jpg").getFile();

        BatchImageRecord batch = new BatchImageRecord();
        batch.add(f0.toURI());
        batch.add(f1.toURI());
        batch.add(f2.toURI());

        ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(seed)
                .scaleImageTransform(10).cropImageTransform(5).build();

        ImageSparkTransform imgSparkTransform = new ImageSparkTransform(imgTransformProcess);
        Base64NDArrayBody body = imgSparkTransform.toArray(batch);

        INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
        // System.out.println("Base 64ed array " + fromBase64);
        assertEquals(3, fromBase64.size(0));
    }
}
@ -1,60 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.transform;

import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.junit.Test;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;

public class SingleCSVRecordTest {

    @Test(expected = IllegalArgumentException.class)
    public void testVectorAssertion() {
        DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(1, 1));
        SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet);
        fail(singleCsvRecord.toString() + " should have thrown an exception");
    }

    @Test
    public void testVectorOneHotLabel() {
        DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{0, 1}, {1, 0}}));

        //assert
        SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet.get(0));
        assertEquals(3, singleCsvRecord.getValues().size());
    }

    @Test
    public void testVectorRegression() {
        DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{1, 1}, {1, 1}}));

        //assert
        SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet.get(0));
        assertEquals(4, singleCsvRecord.getValues().size());
    }
}
@ -1,47 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.transform;

import org.datavec.spark.inference.model.model.SingleImageRecord;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.common.io.ClassPathResource;

import java.io.File;

public class SingleImageRecordTest {

    @Rule
    public TemporaryFolder testDir = new TemporaryFolder();

    @Test
    public void testImageRecord() throws Exception {
        File f = testDir.newFolder();
        new ClassPathResource("datavec-spark-inference/testimages/").copyDirectory(f);
        File f0 = new File(f, "class0/0.jpg");
        File f1 = new File(f, "class1/A.jpg");

        SingleImageRecord imgRecord = new SingleImageRecord(f0.toURI());

        // need jackson test?
    }
}
@ -1,154 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  ~ /* ******************************************************************************
  ~ *
  ~ *
  ~ * This program and the accompanying materials are made available under the
  ~ * terms of the Apache License, Version 2.0 which is available at
  ~ * https://www.apache.org/licenses/LICENSE-2.0.
  ~ *
  ~ * See the NOTICE file distributed with this work for additional
  ~ * information regarding copyright ownership.
  ~ * Unless required by applicable law or agreed to in writing, software
  ~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~ * License for the specific language governing permissions and limitations
  ~ * under the License.
  ~ *
  ~ * SPDX-License-Identifier: Apache-2.0
  ~ ******************************************************************************/
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-spark-inference-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>datavec-spark-inference-server_2.11</artifactId>

    <name>datavec-spark-inference-server</name>

    <properties>
        <!-- Default scala versions, may be overwritten by build profiles -->
        <scala.version>2.11.12</scala.version>
        <scala.binary.version>2.11</scala.binary.version>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-spark-inference-model</artifactId>
            <version>${datavec.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-spark_2.11</artifactId>
            <version>${project.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-data-image</artifactId>
        </dependency>
        <dependency>
            <groupId>joda-time</groupId>
            <artifactId>joda-time</artifactId>
        </dependency>
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-lang3</artifactId>
        </dependency>
        <dependency>
            <groupId>org.hibernate</groupId>
            <artifactId>hibernate-validator</artifactId>
            <version>${hibernate.version}</version>
        </dependency>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-library</artifactId>
            <version>${scala.version}</version>
        </dependency>
        <dependency>
            <groupId>org.scala-lang</groupId>
            <artifactId>scala-reflect</artifactId>
            <version>${scala.version}</version>
        </dependency>
        <dependency>
            <groupId>com.typesafe.play</groupId>
            <artifactId>play-java_2.11</artifactId>
            <version>${playframework.version}</version>
            <exclusions>
                <exclusion>
                    <groupId>com.google.code.findbugs</groupId>
                    <artifactId>jsr305</artifactId>
                </exclusion>
                <exclusion>
                    <groupId>net.jodah</groupId>
                    <artifactId>typetools</artifactId>
                </exclusion>
            </exclusions>
        </dependency>
        <dependency>
            <groupId>net.jodah</groupId>
            <artifactId>typetools</artifactId>
            <version>${jodah.typetools.version}</version>
        </dependency>
        <dependency>
            <groupId>com.typesafe.play</groupId>
            <artifactId>play-json_2.11</artifactId>
            <version>${playframework.version}</version>
        </dependency>
        <dependency>
            <groupId>com.typesafe.play</groupId>
            <artifactId>play-server_2.11</artifactId>
            <version>${playframework.version}</version>
        </dependency>
        <dependency>
            <groupId>com.typesafe.play</groupId>
            <artifactId>play_2.11</artifactId>
            <version>${playframework.version}</version>
        </dependency>
        <dependency>
            <groupId>com.typesafe.play</groupId>
            <artifactId>play-netty-server_2.11</artifactId>
            <version>${playframework.version}</version>
        </dependency>
        <dependency>
            <groupId>com.typesafe.akka</groupId>
            <artifactId>akka-cluster_2.11</artifactId>
            <version>2.5.23</version>
        </dependency>
        <dependency>
            <groupId>com.mashape.unirest</groupId>
            <artifactId>unirest-java</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>com.beust</groupId>
            <artifactId>jcommander</artifactId>
            <version>${jcommander.version}</version>
        </dependency>
        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.11</artifactId>
            <version>${spark.version}</version>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
        </profile>
    </profiles>
</project>
@ -1,352 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.server;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.ParameterException;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.CSVSparkTransform;
import org.datavec.spark.inference.model.model.*;
import play.BuiltInComponents;
import play.Mode;
import play.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;

import java.io.File;
import java.io.IOException;
import java.util.Base64;
import java.util.Random;

import static play.mvc.Results.*;

@Slf4j
@Data
public class CSVSparkTransformServer extends SparkTransformServer {
    private CSVSparkTransform transform;

    public void runMain(String[] args) throws Exception {
        JCommander jcmdr = new JCommander(this);

        try {
            jcmdr.parse(args);
        } catch (ParameterException e) {
            //User provided invalid input -> print the usage info
            jcmdr.usage();
            if (jsonPath == null)
                System.err.println("Json path parameter is missing.");
            try {
                Thread.sleep(500);
            } catch (Exception e2) {
                //ignore; exiting anyway
            }
            System.exit(1);
        }

        if (jsonPath != null) {
            String json = FileUtils.readFileToString(new File(jsonPath));
            TransformProcess transformProcess = TransformProcess.fromJson(json);
            transform = new CSVSparkTransform(transformProcess);
        } else {
            log.warn("Server started with no json for transform process. Please ensure you specify a transform process "
                    + "by sending a POST request with raw json to /transformprocess");
        }

        //Set play secret key, if required
        //http://www.playframework.com/documentation/latest/ApplicationSecret
        String crypto = System.getProperty("play.crypto.secret");
        if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) {
            byte[] newCrypto = new byte[1024];

            new Random().nextBytes(newCrypto);

            String base64 = Base64.getEncoder().encodeToString(newCrypto);
            System.setProperty("play.crypto.secret", base64);
        }

        server = Server.forRouter(Mode.PROD, port, this::createRouter);
    }

    protected Router createRouter(BuiltInComponents b) {
        RoutingDsl routingDsl = RoutingDsl.fromComponents(b);

        routingDsl.GET("/transformprocess").routingTo(req -> {
            try {
                if (transform == null)
                    return badRequest();
                return ok(transform.getTransformProcess().toJson()).as(contentType);
            } catch (Exception e) {
                log.error("Error in GET /transformprocess", e);
                return internalServerError(e.getMessage());
            }
        });

        routingDsl.POST("/transformprocess").routingTo(req -> {
            try {
                TransformProcess transformProcess = TransformProcess.fromJson(getJsonText(req));
                setCSVTransformProcess(transformProcess);
                log.info("Transform process initialized");
                return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
            } catch (Exception e) {
                log.error("Error in POST /transformprocess", e);
                return internalServerError(e.getMessage());
            }
        });

        routingDsl.POST("/transformincremental").routingTo(req -> {
            if (isSequence(req)) {
                try {
                    BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
                    if (record == null)
                        return badRequest();
                    return ok(objectMapper.writeValueAsString(transformSequenceIncremental(record))).as(contentType);
                } catch (Exception e) {
                    log.error("Error in /transformincremental", e);
                    return internalServerError(e.getMessage());
                }
            } else {
                try {
                    SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class);
                    if (record == null)
                        return badRequest();
                    return ok(objectMapper.writeValueAsString(transformIncremental(record))).as(contentType);
                } catch (Exception e) {
                    log.error("Error in /transformincremental", e);
                    return internalServerError(e.getMessage());
                }
            }
        });

        routingDsl.POST("/transform").routingTo(req -> {
            if (isSequence(req)) {
                try {
                    SequenceBatchCSVRecord batch = transformSequence(objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class));
                    if (batch == null)
                        return badRequest();
                    return ok(objectMapper.writeValueAsString(batch)).as(contentType);
                } catch (Exception e) {
                    log.error("Error in /transform", e);
                    return internalServerError(e.getMessage());
                }
            } else {
                try {
                    BatchCSVRecord input = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
                    BatchCSVRecord batch = transform(input);
                    if (batch == null)
                        return badRequest();
                    return ok(objectMapper.writeValueAsString(batch)).as(contentType);
                } catch (Exception e) {
                    log.error("Error in /transform", e);
                    return internalServerError(e.getMessage());
                }
            }
        });

        routingDsl.POST("/transformincrementalarray").routingTo(req -> {
            if (isSequence(req)) {
                try {
                    BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
                    if (record == null)
                        return badRequest();
                    return ok(objectMapper.writeValueAsString(transformSequenceArrayIncremental(record))).as(contentType);
                } catch (Exception e) {
                    log.error("Error in /transformincrementalarray", e);
                    return internalServerError(e.getMessage());
                }
            } else {
                try {
                    SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class);
                    if (record == null)
                        return badRequest();
                    return ok(objectMapper.writeValueAsString(transformArrayIncremental(record))).as(contentType);
                } catch (Exception e) {
                    log.error("Error in /transformincrementalarray", e);
                    return internalServerError(e.getMessage());
                }
            }
        });

        routingDsl.POST("/transformarray").routingTo(req -> {
            if (isSequence(req)) {
                try {
                    SequenceBatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class);
                    if (batchCSVRecord == null)
                        return badRequest();
                    return ok(objectMapper.writeValueAsString(transformSequenceArray(batchCSVRecord))).as(contentType);
                } catch (Exception e) {
                    log.error("Error in /transformarray", e);
                    return internalServerError(e.getMessage());
                }
            } else {
                try {
                    BatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
                    if (batchCSVRecord == null)
                        return badRequest();
                    return ok(objectMapper.writeValueAsString(transformArray(batchCSVRecord))).as(contentType);
                } catch (Exception e) {
                    log.error("Error in /transformarray", e);
                    return internalServerError(e.getMessage());
                }
            }
        });

        return routingDsl.build();
    }

    public static void main(String[] args) throws Exception {
        new CSVSparkTransformServer().runMain(args);
    }

    /**
     * @param transformProcess the CSV transform process to use
     */
    @Override
    public void setCSVTransformProcess(TransformProcess transformProcess) {
        this.transform = new CSVSparkTransform(transformProcess);
    }

    @Override
    public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) {
        log.error("Unsupported operation: setImageTransformProcess not supported for class {}", getClass());
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    /**
     * @return the current CSV transform process
     */
    @Override
    public TransformProcess getCSVTransformProcess() {
        return transform.getTransformProcess();
    }

    @Override
    public ImageTransformProcess getImageTransformProcess() {
        log.error("Unsupported operation: getImageTransformProcess not supported for class {}", getClass());
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    /**
     * @param transform the batch to transform
     * @return the transformed sequence batch
     */
    @Override
    public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) {
        return this.transform.transformSequenceIncremental(transform);
    }

    /**
     * @param batchCSVRecord the sequence batch to transform
     * @return the transformed sequence batch
     */
    @Override
    public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) {
        return transform.transformSequence(batchCSVRecord);
    }

    /**
     * @param batchCSVRecord the sequence batch to transform
     * @return the resulting array, base64 encoded
     */
    @Override
    public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) {
        return this.transform.transformSequenceArray(batchCSVRecord);
    }

    /**
     * @param singleCsvRecord the batch to transform
     * @return the resulting array, base64 encoded
     */
    @Override
    public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) {
        return this.transform.transformSequenceArrayIncremental(singleCsvRecord);
    }

    /**
     * @param transform the record to transform
     * @return the transformed record
     */
    @Override
    public SingleCSVRecord transformIncremental(SingleCSVRecord transform) {
        return this.transform.transform(transform);
    }

    @Override
    public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
        return this.transform.transform(batchCSVRecord);
    }

    /**
     * @param batchCSVRecord the batch to transform
     * @return the transformed batch
     */
    @Override
    public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) {
        return transform.transform(batchCSVRecord);
    }

    /**
     * @param batchCSVRecord the batch to transform
     * @return the resulting array, base64 encoded
     */
    @Override
    public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) {
        try {
            return this.transform.toArray(batchCSVRecord);
        } catch (IOException e) {
            log.error("Error in transformArray", e);
            throw new IllegalStateException("Transform array shouldn't throw exception");
        }
    }

    /**
     * @param singleCsvRecord the record to transform
     * @return the resulting array, base64 encoded
     */
    @Override
    public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) {
        try {
            return this.transform.toArray(singleCsvRecord);
        } catch (IOException e) {
            log.error("Error in transformArrayIncremental", e);
            throw new IllegalStateException("Transform array shouldn't throw exception");
        }
    }

    @Override
    public Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException {
        log.error("Unsupported operation: transformIncrementalArray(SingleImageRecord) not supported for class {}", getClass());
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    @Override
    public Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException {
        log.error("Unsupported operation: transformArray(BatchImageRecord) not supported for class {}", getClass());
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }
}
@ -1,261 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.server;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.ParameterException;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.ImageSparkTransform;
import org.datavec.spark.inference.model.model.*;
import play.BuiltInComponents;
import play.Mode;
import play.libs.Files;
import play.mvc.Http;
import play.routing.Router;
import play.routing.RoutingDsl;
import play.server.Server;

import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import static play.mvc.Results.*;

@Slf4j
@Data
public class ImageSparkTransformServer extends SparkTransformServer {
    private ImageSparkTransform transform;

    public void runMain(String[] args) throws Exception {
        JCommander jcmdr = new JCommander(this);

        try {
            jcmdr.parse(args);
        } catch (ParameterException e) {
            //User provided invalid input -> print the usage info
            jcmdr.usage();
            if (jsonPath == null)
                System.err.println("Json path parameter is missing.");
            try {
                Thread.sleep(500);
            } catch (Exception e2) {
                //ignore; exiting anyway
            }
            System.exit(1);
        }

        if (jsonPath != null) {
            String json = FileUtils.readFileToString(new File(jsonPath));
            ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(json);
            transform = new ImageSparkTransform(transformProcess);
        } else {
            log.warn("Server started with no json for transform process. Please ensure you specify a transform process "
                    + "by sending a POST request with raw json to /transformprocess");
        }

        server = Server.forRouter(Mode.PROD, port, this::createRouter);
    }

    protected Router createRouter(BuiltInComponents builtInComponents) {
        RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents);

        routingDsl.GET("/transformprocess").routingTo(req -> {
            try {
                if (transform == null)
                    return badRequest();
                return ok(objectMapper.writeValueAsString(transform.getImageTransformProcess())).as(contentType);
            } catch (Exception e) {
                log.error("Error in GET /transformprocess", e);
                return internalServerError();
            }
        });

        routingDsl.POST("/transformprocess").routingTo(req -> {
            try {
                ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(getJsonText(req));
                setImageTransformProcess(transformProcess);
                log.info("Transform process initialized");
                return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
            } catch (Exception e) {
                log.error("Error in POST /transformprocess", e);
                return internalServerError();
            }
        });

        routingDsl.POST("/transformincrementalarray").routingTo(req -> {
            try {
                SingleImageRecord record = objectMapper.readValue(getJsonText(req), SingleImageRecord.class);
                if (record == null)
                    return badRequest();
                return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
            } catch (Exception e) {
                log.error("Error in /transformincrementalarray", e);
                return internalServerError();
            }
        });

        routingDsl.POST("/transformincrementalimage").routingTo(req -> {
            try {
                Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData();
                List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles();
                if (files.isEmpty() || files.get(0).getRef() == null) {
                    return badRequest();
                }

                File file = files.get(0).getRef().path().toFile();
                SingleImageRecord record = new SingleImageRecord(file.toURI());

                return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
            } catch (Exception e) {
                log.error("Error in /transformincrementalimage", e);
                return internalServerError();
            }
        });

        routingDsl.POST("/transformarray").routingTo(req -> {
            try {
                BatchImageRecord batch = objectMapper.readValue(getJsonText(req), BatchImageRecord.class);
                if (batch == null)
                    return badRequest();
                return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
            } catch (Exception e) {
                log.error("Error in /transformarray", e);
                return internalServerError();
            }
        });

        routingDsl.POST("/transformimage").routingTo(req -> {
            try {
                Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData();
                List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles();
                if (files.isEmpty()) {
                    return badRequest();
                }

                List<SingleImageRecord> records = new ArrayList<>();

                for (Http.MultipartFormData.FilePart<Files.TemporaryFile> filePart : files) {
                    Files.TemporaryFile file = filePart.getRef();
                    if (file != null) {
                        SingleImageRecord record = new SingleImageRecord(file.path().toUri());
                        records.add(record);
                    }
                }

                BatchImageRecord batch = new BatchImageRecord(records);

                return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
            } catch (Exception e) {
                log.error("Error in /transformimage", e);
                return internalServerError();
            }
        });

        return routingDsl.build();
    }

    @Override
    public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) {
        throw new UnsupportedOperationException();
    }

    @Override
    public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) {
        throw new UnsupportedOperationException();
    }

    @Override
    public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) {
        throw new UnsupportedOperationException();
    }

    @Override
    public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) {
        throw new UnsupportedOperationException();
    }

    @Override
    public void setCSVTransformProcess(TransformProcess transformProcess) {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    @Override
    public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) {
        this.transform = new ImageSparkTransform(imageTransformProcess);
    }

    @Override
    public TransformProcess getCSVTransformProcess() {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    @Override
    public ImageTransformProcess getImageTransformProcess() {
        return transform.getImageTransformProcess();
    }

    @Override
    public SingleCSVRecord transformIncremental(SingleCSVRecord singleCsvRecord) {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    @Override
    public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    @Override
    public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    @Override
    public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    @Override
    public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    @Override
    public Base64NDArrayBody transformIncrementalArray(SingleImageRecord record) throws IOException {
        return transform.toArray(record);
    }

    @Override
    public Base64NDArrayBody transformArray(BatchImageRecord batch) throws IOException {
        return transform.toArray(batch);
    }

    public static void main(String[] args) throws Exception {
        new ImageSparkTransformServer().runMain(args);
    }
}
@ -1,67 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.server;

import com.beust.jcommander.Parameter;
import com.fasterxml.jackson.databind.JsonNode;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.service.DataVecTransformService;
import org.nd4j.shade.jackson.databind.ObjectMapper;
import play.mvc.Http;
import play.server.Server;

public abstract class SparkTransformServer implements DataVecTransformService {
    @Parameter(names = {"-j", "--jsonPath"}, arity = 1)
    protected String jsonPath = null;
    @Parameter(names = {"-dp", "--dataVecPort"}, arity = 1)
    protected int port = 9000;
    @Parameter(names = {"-dt", "--dataType"}, arity = 1)
    private TransformDataType transformDataType = null;
    protected Server server;
    protected static ObjectMapper objectMapper = new ObjectMapper();
    protected static String contentType = "application/json";

    public abstract void runMain(String[] args) throws Exception;

    /**
     * Stop the server
     */
    public void stop() {
        if (server != null)
            server.stop();
    }

    /**
     * @return true if the request carries the "Sequence" header set to "true"
     */
    protected boolean isSequence(Http.Request request) {
        return request.hasHeader(SEQUENCE_OR_NOT_HEADER)
                && request.header(SEQUENCE_OR_NOT_HEADER).get().equalsIgnoreCase("true");
    }

    /**
     * Extract the raw JSON payload, whether it was sent as JSON or as plain text.
     */
    protected String getJsonText(Http.Request request) {
        JsonNode tryJson = request.body().asJson();
        if (tryJson != null)
            return tryJson.toString();
        else
            return request.body().asText();
    }

    public abstract Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord);
}
@ -1,76 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.inference.server;

import lombok.Data;
import lombok.extern.slf4j.Slf4j;

import java.io.InvalidClassException;
import java.util.Arrays;
import java.util.List;

@Data
@Slf4j
public class SparkTransformServerChooser {
    private SparkTransformServer sparkTransformServer = null;
    private TransformDataType transformDataType = null;

    public void runMain(String[] args) throws Exception {

        int pos = getMatchingPosition(args, "-dt", "--dataType");
        if (pos == -1) {
            log.error("no valid options");
            log.error("-dt, --dataType   Options: [CSV, IMAGE]");
            throw new Exception("no valid options");
        } else {
            transformDataType = TransformDataType.valueOf(args[pos + 1]);
        }

        switch (transformDataType) {
            case CSV:
                sparkTransformServer = new CSVSparkTransformServer();
                break;
            case IMAGE:
                sparkTransformServer = new ImageSparkTransformServer();
                break;
            default:
                throw new InvalidClassException("no matching SparkTransform class");
        }

        sparkTransformServer.runMain(args);
    }

    private int getMatchingPosition(String[] args, String... options) {
        List<String> optionList = Arrays.asList(options);

        for (int i = 0; i < args.length; i++) {
            if (optionList.contains(args[i])) {
                return i;
            }
        }
        return -1;
    }

    public static void main(String[] args) throws Exception {
        new SparkTransformServerChooser().runMain(args);
    }
}
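The chooser above was the CLI entry point for both servers. A hypothetical launch from code, where the flags match the @Parameter definitions in SparkTransformServer and the JSON path is illustrative only:

public class LaunchSketch {
    public static void main(String[] args) throws Exception {
        // Start the CSV variant on port 9000, loading a TransformProcess
        // serialized to JSON at an assumed path.
        new SparkTransformServerChooser().runMain(new String[] {
                "-dt", "CSV", "-j", "/tmp/transformprocess.json", "-dp", "9000"});
    }
}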
@ -1,25 +0,0 @@
|
|||
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.datavec.spark.inference.server;
|
||||
|
||||
public enum TransformDataType {
|
||||
CSV, IMAGE,
|
||||
}
|
|
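
The chooser above dispatches on the -dt flag to one of the two server implementations and passes the remaining flags through to it. A minimal launch sketch, with an illustrative JSON path and port (the flag names are the ones exercised by the tests further down; assumes this class sits in org.datavec.spark.inference.server alongside the chooser):

public class ChooserLauncher {
    public static void main(String[] args) throws Exception {
        // Starts a CSV transform server on data port 9050;
        // "-dt IMAGE" would select ImageSparkTransformServer instead.
        new SparkTransformServerChooser().runMain(new String[] {
                "-dt", TransformDataType.CSV.toString(),
                "--jsonPath", "/path/to/transformprocess.json", // illustrative path
                "-dp", "9050"});
    }
}
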
@ -1,350 +0,0 @@
# This is the main configuration file for the application.
# https://www.playframework.com/documentation/latest/ConfigFile
# ~~~~~
# Play uses HOCON as its configuration file format. HOCON has a number
# of advantages over other config formats, but there are two things that
# can be used when modifying settings.
#
# You can include other configuration files in this main application.conf file:
#include "extra-config.conf"
#
# You can declare variables and substitute for them:
#mykey = ${some.value}
#
# And if an environment variable exists when there is no other substitution, then
# HOCON will fall back to substituting the environment variable:
#mykey = ${JAVA_HOME}

## Akka
# https://www.playframework.com/documentation/latest/ScalaAkka#Configuration
# https://www.playframework.com/documentation/latest/JavaAkka#Configuration
# ~~~~~
# Play uses Akka internally and exposes Akka Streams and actors in Websockets and
# other streaming HTTP responses.
akka {
  # "akka.log-config-on-start" is extraordinarily useful because it logs the complete
  # configuration at INFO level, including defaults and overrides, so it's worth
  # putting at the very top.
  #
  # Put the following in your conf/logback.xml file:
  #
  # <logger name="akka.actor" level="INFO" />
  #
  # And then uncomment this line to debug the configuration.
  #
  #log-config-on-start = true
}

## Modules
# https://www.playframework.com/documentation/latest/Modules
# ~~~~~
# Control which modules are loaded when Play starts. Note that modules are
# the replacement for "GlobalSettings", which are deprecated in 2.5.x.
# Please see https://www.playframework.com/documentation/latest/GlobalSettings
# for more information.
#
# You can also extend Play functionality by using one of the publicly available
# Play modules: https://playframework.com/documentation/latest/ModuleDirectory
play.modules {
  # By default, Play will load any class called Module that is defined
  # in the root package (the "app" directory), or you can define them
  # explicitly below.
  # If there are any built-in modules that you want to enable, you can list them here.
  #enabled += my.application.Module

  # If there are any built-in modules that you want to disable, you can list them here.
  #disabled += ""
}

## Internationalisation
# https://www.playframework.com/documentation/latest/JavaI18N
# https://www.playframework.com/documentation/latest/ScalaI18N
# ~~~~~
# Play comes with its own i18n settings, which allow the user's preferred language
# to map through to internal messages, or allow the language to be stored in a cookie.
play.i18n {
  # The application languages
  langs = [ "en" ]

  # Whether the language cookie should be secure or not
  #langCookieSecure = true

  # Whether the HTTP only attribute of the cookie should be set to true
  #langCookieHttpOnly = true
}

## Play HTTP settings
# ~~~~~
play.http {
  ## Router
  # https://www.playframework.com/documentation/latest/JavaRouting
  # https://www.playframework.com/documentation/latest/ScalaRouting
  # ~~~~~
  # Define the Router object to use for this application.
  # This router will be looked up first when the application is starting up,
  # so make sure this is the entry point.
  # Furthermore, it's assumed your route file is named properly.
  # So for an application router like `my.application.Router`,
  # you may need to define a router file `conf/my.application.routes`.
  # Default to Routes in the root package (aka "apps" folder) (and conf/routes)
  #router = my.application.Router

  ## Action Creator
  # https://www.playframework.com/documentation/latest/JavaActionCreator
  # ~~~~~
  #actionCreator = null

  ## ErrorHandler
  # https://www.playframework.com/documentation/latest/JavaRouting
  # https://www.playframework.com/documentation/latest/ScalaRouting
  # ~~~~~
  # If null, will attempt to load a class called ErrorHandler in the root package.
  #errorHandler = null

  ## Filters
  # https://www.playframework.com/documentation/latest/ScalaHttpFilters
  # https://www.playframework.com/documentation/latest/JavaHttpFilters
  # ~~~~~
  # Filters run code on every request. They can be used to perform
  # common logic for all your actions, e.g. adding common headers.
  # Defaults to "Filters" in the root package (aka "apps" folder)
  # Alternatively you can explicitly register a class here.
  #filters += my.application.Filters

  ## Session & Flash
  # https://www.playframework.com/documentation/latest/JavaSessionFlash
  # https://www.playframework.com/documentation/latest/ScalaSessionFlash
  # ~~~~~
  session {
    # Sets the cookie to be sent only over HTTPS.
    #secure = true

    # Sets the cookie to be accessed only by the server.
    #httpOnly = true

    # Sets the max-age field of the cookie to 5 minutes.
    # NOTE: this only sets when the browser will discard the cookie. Play will consider any
    # cookie value with a valid signature to be a valid session forever. To implement a server-side session timeout,
    # you need to put a timestamp in the session and check it at regular intervals to possibly expire it.
    #maxAge = 300

    # Sets the domain on the session cookie.
    #domain = "example.com"
  }

  flash {
    # Sets the cookie to be sent only over HTTPS.
    #secure = true

    # Sets the cookie to be accessed only by the server.
    #httpOnly = true
  }
}

## Netty Provider
# https://www.playframework.com/documentation/latest/SettingsNetty
# ~~~~~
play.server.netty {
  # Whether the Netty wire should be logged
  #log.wire = true

  # If you run Play on Linux, you can use Netty's native socket transport
  # for higher performance with less garbage.
  #transport = "native"
}

## WS (HTTP Client)
# https://www.playframework.com/documentation/latest/ScalaWS#Configuring-WS
# ~~~~~
# The HTTP client primarily used for REST APIs. The default client can be
# configured directly, but you can also create different client instances
# with customized settings. You must enable this by adding to build.sbt:
#
# libraryDependencies += ws // or javaWs if using java
#
play.ws {
  # Sets HTTP requests not to follow 302 redirects
  #followRedirects = false

  # Sets the maximum number of open HTTP connections for the client.
  #ahc.maxConnectionsTotal = 50

  ## WS SSL
  # https://www.playframework.com/documentation/latest/WsSSL
  # ~~~~~
  ssl {
    # Configuring HTTPS with Play WS does not require programming. You can
    # set up both trustManager and keyManager for mutual authentication, and
    # turn on JSSE debugging in development with a reload.
    #debug.handshake = true
    #trustManager = {
    #  stores = [
    #    { type = "JKS", path = "exampletrust.jks" }
    #  ]
    #}
  }
}

## Cache
# https://www.playframework.com/documentation/latest/JavaCache
# https://www.playframework.com/documentation/latest/ScalaCache
# ~~~~~
# Play comes with an integrated cache API that can reduce the operational
# overhead of repeated requests. You must enable this by adding to build.sbt:
#
# libraryDependencies += cache
#
play.cache {
  # If you want to bind several caches, you can bind them individually
  #bindCaches = ["db-cache", "user-cache", "session-cache"]
}

## Filters
# https://www.playframework.com/documentation/latest/Filters
# ~~~~~
# There are a number of built-in filters that can be enabled and configured
# to give Play greater security. You must enable this by adding to build.sbt:
#
# libraryDependencies += filters
#
play.filters {
  ## CORS filter configuration
  # https://www.playframework.com/documentation/latest/CorsFilter
  # ~~~~~
  # CORS is a protocol that allows web applications to make requests from the browser
  # across different domains.
  # NOTE: You MUST apply the CORS configuration before the CSRF filter, as CSRF has
  # dependencies on CORS settings.
  cors {
    # Filter paths by a whitelist of path prefixes
    #pathPrefixes = ["/some/path", ...]

    # The allowed origins. If null, all origins are allowed.
    #allowedOrigins = ["http://www.example.com"]

    # The allowed HTTP methods. If null, all methods are allowed
    #allowedHttpMethods = ["GET", "POST"]
  }

  ## CSRF Filter
  # https://www.playframework.com/documentation/latest/ScalaCsrf#Applying-a-global-CSRF-filter
  # https://www.playframework.com/documentation/latest/JavaCsrf#Applying-a-global-CSRF-filter
  # ~~~~~
  # Play supports multiple methods for verifying that a request is not a CSRF request.
  # The primary mechanism is a CSRF token. This token gets placed either in the query string
  # or body of every form submitted, and also gets placed in the user's session.
  # Play then verifies that both tokens are present and match.
  csrf {
    # Sets the cookie to be sent only over HTTPS
    #cookie.secure = true

    # Defaults to CSRFErrorHandler in the root package.
    #errorHandler = MyCSRFErrorHandler
  }

  ## Security headers filter configuration
  # https://www.playframework.com/documentation/latest/SecurityHeaders
  # ~~~~~
  # Defines security headers that prevent XSS attacks.
  # If enabled, then all options are set to the below configuration by default:
  headers {
    # The X-Frame-Options header. If null, the header is not set.
    #frameOptions = "DENY"

    # The X-XSS-Protection header. If null, the header is not set.
    #xssProtection = "1; mode=block"

    # The X-Content-Type-Options header. If null, the header is not set.
    #contentTypeOptions = "nosniff"

    # The X-Permitted-Cross-Domain-Policies header. If null, the header is not set.
    #permittedCrossDomainPolicies = "master-only"

    # The Content-Security-Policy header. If null, the header is not set.
    #contentSecurityPolicy = "default-src 'self'"
  }

  ## Allowed hosts filter configuration
  # https://www.playframework.com/documentation/latest/AllowedHostsFilter
  # ~~~~~
  # Play provides a filter that lets you configure which hosts can access your application.
  # This is useful to prevent cache poisoning attacks.
  hosts {
    # Allow requests to example.com, its subdomains, and localhost:9000.
    #allowed = [".example.com", "localhost:9000"]
  }
}

## Evolutions
# https://www.playframework.com/documentation/latest/Evolutions
# ~~~~~
# Evolutions allows database scripts to be automatically run on startup in dev mode
# for database migrations. You must enable this by adding to build.sbt:
#
# libraryDependencies += evolutions
#
play.evolutions {
  # You can disable evolutions for a specific datasource if necessary
  #db.default.enabled = false
}

## Database Connection Pool
# https://www.playframework.com/documentation/latest/SettingsJDBC
# ~~~~~
# Play doesn't require a JDBC database to run, but you can easily enable one.
#
# libraryDependencies += jdbc
#
play.db {
  # The combination of these two settings results in "db.default" as the
  # default JDBC pool:
  #config = "db"
  #default = "default"

  # Play uses HikariCP as the default connection pool. You can override
  # settings by changing the prototype:
  prototype {
    # Sets a fixed JDBC connection pool size of 50
    #hikaricp.minimumIdle = 50
    #hikaricp.maximumPoolSize = 50
  }
}

## JDBC Datasource
# https://www.playframework.com/documentation/latest/JavaDatabase
# https://www.playframework.com/documentation/latest/ScalaDatabase
# ~~~~~
# Once JDBC datasource is set up, you can work with several different
# database options:
#
# Slick (Scala preferred option): https://www.playframework.com/documentation/latest/PlaySlick
# JPA (Java preferred option): https://playframework.com/documentation/latest/JavaJPA
# EBean: https://playframework.com/documentation/latest/JavaEbean
# Anorm: https://www.playframework.com/documentation/latest/ScalaAnorm
#
db {
  # You can declare as many datasources as you want.
  # By convention, the default datasource is named `default`

  # https://www.playframework.com/documentation/latest/Developing-with-the-H2-Database
  default.driver = org.h2.Driver
  default.url = "jdbc:h2:mem:play"
  #default.username = sa
  #default.password = ""

  # You can expose this datasource via JNDI if needed (Useful for JPA)
  default.jndiName=DefaultDS

  # You can turn on SQL logging for any datasource
  # https://www.playframework.com/documentation/latest/Highlights25#Logging-SQL-statements
  #default.logSql=true
}

jpa.default=defaultPersistenceUnit


# Increase the default maximum post length - used for remote listener functionality.
# Without this, larger networks can produce HTTP 413 (Payload Too Large) responses.
# parsers.text.maxLength is deprecated, use play.http.parser.maxMemoryBuffer instead
#parsers.text.maxLength=10M
play.http.parser.maxMemoryBuffer=10M
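
Everything above follows standard HOCON resolution: later definitions, includes, and environment fallbacks override earlier ones. A minimal sketch of checking the effective value of the override set at the end of this file, assuming the Typesafe Config library (which backs Play's configuration) is on the classpath:

import com.typesafe.config.Config;
import com.typesafe.config.ConfigFactory;

public class ConfigCheck {
    public static void main(String[] args) {
        // load() parses application.conf from the classpath and resolves
        // ${...} substitutions, falling back to environment variables.
        Config conf = ConfigFactory.load();
        System.out.println(conf.getString("play.http.parser.maxMemoryBuffer")); // prints "10M"
    }
}
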
@ -1,46 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */
package org.datavec.spark.transform;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;

import java.util.*;

@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {

    @Override
    protected Set<Class<?>> getExclusions() {
        //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
        return new HashSet<>();
    }

    @Override
    protected String getPackageName() {
        return "org.datavec.spark.transform";
    }

    @Override
    protected Class<?> getBaseClass() {
        return BaseND4JTest.class;
    }
}
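
The checker above scans the named package and fails the build if any test class does not extend the given base class. For illustration, a compliant test would look like this (the class and method names are hypothetical):

import org.junit.Test;
import org.nd4j.common.tests.BaseND4JTest;

public class ExampleCompliantTest extends BaseND4JTest {
    @Test
    public void testSomething() {
        // Inherits the logging and timeout behaviour enforced by BaseND4JTest.
    }
}
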
@ -1,127 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.transform;

import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.spark.inference.server.CSVSparkTransformServer;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

import java.io.File;
import java.io.IOException;
import java.util.UUID;

import static org.junit.Assert.assertNull;
import static org.junit.Assume.assumeNotNull;

public class CSVSparkTransformServerNoJsonTest {

    private static CSVSparkTransformServer server;
    private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
    private static TransformProcess transformProcess =
            new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build();
    private static File fileSave = new File(UUID.randomUUID().toString() + ".json");

    @BeforeClass
    public static void before() throws Exception {
        server = new CSVSparkTransformServer();
        FileUtils.write(fileSave, transformProcess.toJson());

        // Only one time
        Unirest.setObjectMapper(new ObjectMapper() {
            private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
                    new org.nd4j.shade.jackson.databind.ObjectMapper();

            public <T> T readValue(String value, Class<T> valueType) {
                try {
                    return jacksonObjectMapper.readValue(value, valueType);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }

            public String writeValue(Object value) {
                try {
                    return jacksonObjectMapper.writeValueAsString(value);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        });

        server.runMain(new String[] {"-dp", "9050"});
    }

    @AfterClass
    public static void after() throws Exception {
        fileSave.delete();
        server.stop();
    }

    @Test
    public void testServer() throws Exception {
        assertNull(server.getTransform());
        JsonNode jsonStatus = Unirest.post("http://localhost:9050/transformprocess")
                .header("accept", "application/json").header("Content-Type", "application/json")
                .body(transformProcess.toJson()).asJson().getBody();
        assumeNotNull(server.getTransform());

        String[] values = new String[] {"1.0", "2.0"};
        SingleCSVRecord record = new SingleCSVRecord(values);
        JsonNode jsonNode =
                Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json")
                        .header("Content-Type", "application/json").body(record).asJson().getBody();
        SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental")
                .header("accept", "application/json").header("Content-Type", "application/json").body(record)
                .asObject(SingleCSVRecord.class).getBody();

        BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
        for (int i = 0; i < 3; i++)
            batchCSVRecord.add(singleCsvRecord);
        /* BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform")
                .header("accept", "application/json").header("Content-Type", "application/json")
                .body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody();

        Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray")
                .header("accept", "application/json").header("Content-Type", "application/json").body(record)
                .asObject(Base64NDArrayBody.class).getBody();
        */
        Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray")
                .header("accept", "application/json").header("Content-Type", "application/json")
                .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody();
    }

}
@ -1,121 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.transform;

import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.spark.inference.server.CSVSparkTransformServer;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;

import java.io.File;
import java.io.IOException;
import java.util.UUID;

public class CSVSparkTransformServerTest {

    private static CSVSparkTransformServer server;
    private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
    private static TransformProcess transformProcess =
            new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build();
    private static File fileSave = new File(UUID.randomUUID().toString() + ".json");

    @BeforeClass
    public static void before() throws Exception {
        server = new CSVSparkTransformServer();
        FileUtils.write(fileSave, transformProcess.toJson());
        // Only one time

        Unirest.setObjectMapper(new ObjectMapper() {
            private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
                    new org.nd4j.shade.jackson.databind.ObjectMapper();

            public <T> T readValue(String value, Class<T> valueType) {
                try {
                    return jacksonObjectMapper.readValue(value, valueType);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }

            public String writeValue(Object value) {
                try {
                    return jacksonObjectMapper.writeValueAsString(value);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        });

        server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9050"});
    }

    @AfterClass
    public static void after() throws Exception {
        fileSave.deleteOnExit();
        server.stop();
    }

    @Test
    public void testServer() throws Exception {
        String[] values = new String[] {"1.0", "2.0"};
        SingleCSVRecord record = new SingleCSVRecord(values);
        JsonNode jsonNode =
                Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json")
                        .header("Content-Type", "application/json").body(record).asJson().getBody();
        SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental")
                .header("accept", "application/json").header("Content-Type", "application/json").body(record)
                .asObject(SingleCSVRecord.class).getBody();

        BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
        for (int i = 0; i < 3; i++)
            batchCSVRecord.add(singleCsvRecord);
        BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform")
                .header("accept", "application/json").header("Content-Type", "application/json")
                .body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody();

        Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray")
                .header("accept", "application/json").header("Content-Type", "application/json").body(record)
                .asObject(Base64NDArrayBody.class).getBody();

        Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray")
                .header("accept", "application/json").header("Content-Type", "application/json")
                .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody();
    }

}
@ -1,164 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.transform;

import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import org.apache.commons.io.FileUtils;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.server.ImageSparkTransformServer;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchImageRecord;
import org.datavec.spark.inference.model.model.SingleImageRecord;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.serde.base64.Nd4jBase64;

import java.io.File;
import java.io.IOException;
import java.util.UUID;

import static org.junit.Assert.assertEquals;

public class ImageSparkTransformServerTest {

    @Rule
    public TemporaryFolder testDir = new TemporaryFolder();

    private static ImageSparkTransformServer server;
    private static File fileSave = new File(UUID.randomUUID().toString() + ".json");

    @BeforeClass
    public static void before() throws Exception {
        server = new ImageSparkTransformServer();

        ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(12345)
                .scaleImageTransform(10).cropImageTransform(5).build();

        FileUtils.write(fileSave, imgTransformProcess.toJson());

        Unirest.setObjectMapper(new ObjectMapper() {
            private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
                    new org.nd4j.shade.jackson.databind.ObjectMapper();

            public <T> T readValue(String value, Class<T> valueType) {
                try {
                    return jacksonObjectMapper.readValue(value, valueType);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }

            public String writeValue(Object value) {
                try {
                    return jacksonObjectMapper.writeValueAsString(value);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        });

        server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9060"});
    }

    @AfterClass
    public static void after() throws Exception {
        fileSave.deleteOnExit();
        server.stop();
    }

    @Test
    public void testImageServer() throws Exception {
        SingleImageRecord record =
                new SingleImageRecord(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI());
        JsonNode jsonNode = Unirest.post("http://localhost:9060/transformincrementalarray")
                .header("accept", "application/json").header("Content-Type", "application/json").body(record)
                .asJson().getBody();
        Base64NDArrayBody array = Unirest.post("http://localhost:9060/transformincrementalarray")
                .header("accept", "application/json").header("Content-Type", "application/json").body(record)
                .asObject(Base64NDArrayBody.class).getBody();

        BatchImageRecord batch = new BatchImageRecord();
        batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI());
        batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile().toURI());
        batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile().toURI());

        JsonNode jsonNodeBatch =
                Unirest.post("http://localhost:9060/transformarray").header("accept", "application/json")
                        .header("Content-Type", "application/json").body(batch).asJson().getBody();
        Base64NDArrayBody batchArray = Unirest.post("http://localhost:9060/transformarray")
                .header("accept", "application/json").header("Content-Type", "application/json").body(batch)
                .asObject(Base64NDArrayBody.class).getBody();

        INDArray result = getNDArray(jsonNode);
        assertEquals(1, result.size(0));

        INDArray batchResult = getNDArray(jsonNodeBatch);
        assertEquals(3, batchResult.size(0));

        // System.out.println(array);
    }

    @Test
    public void testImageServerMultipart() throws Exception {
        JsonNode jsonNode = Unirest.post("http://localhost:9060/transformimage")
                .header("accept", "application/json")
                .field("file1", new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile())
                .field("file2", new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile())
                .field("file3", new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile())
                .asJson().getBody();

        INDArray batchResult = getNDArray(jsonNode);
        assertEquals(3, batchResult.size(0));

        // System.out.println(batchResult);
    }

    @Test
    public void testImageServerSingleMultipart() throws Exception {
        File f = testDir.newFolder();
        File imgFile = new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getTempFileFromArchive(f);

        JsonNode jsonNode = Unirest.post("http://localhost:9060/transformimage")
                .header("accept", "application/json")
                .field("file1", imgFile)
                .asJson().getBody();

        INDArray result = getNDArray(jsonNode);
        assertEquals(1, result.size(0));

        // System.out.println(result);
    }

    public INDArray getNDArray(JsonNode node) throws IOException {
        return Nd4jBase64.fromBase64(node.getObject().getString("ndarray"));
    }
}
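
Each array-returning endpoint above responds with a Base64NDArrayBody whose "ndarray" field is what getNDArray() decodes. A standalone round-trip sketch of that encoding (the array contents are illustrative):

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.serde.base64.Nd4jBase64;

public class Base64RoundTrip {
    public static void main(String[] args) throws Exception {
        INDArray arr = Nd4j.linspace(1, 6, 6).reshape(2, 3);
        String encoded = Nd4jBase64.base64String(arr);     // what the server sends
        INDArray decoded = Nd4jBase64.fromBase64(encoded); // what getNDArray() does
        System.out.println(arr.equals(decoded));           // true
    }
}
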
@ -1,168 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.datavec.spark.transform;

import com.mashape.unirest.http.JsonNode;
import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.server.SparkTransformServerChooser;
import org.datavec.spark.inference.server.TransformDataType;
import org.datavec.spark.inference.model.model.*;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.io.ClassPathResource;
import org.nd4j.serde.base64.Nd4jBase64;

import java.io.File;
import java.io.IOException;
import java.util.UUID;

import static org.junit.Assert.assertEquals;

public class SparkTransformServerTest {
    private static SparkTransformServerChooser serverChooser;
    private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
    private static TransformProcess transformProcess =
            new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build();

    private static File imageTransformFile = new File(UUID.randomUUID().toString() + ".json");
    private static File csvTransformFile = new File(UUID.randomUUID().toString() + ".json");

    @BeforeClass
    public static void before() throws Exception {
        serverChooser = new SparkTransformServerChooser();

        ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(12345)
                .scaleImageTransform(10).cropImageTransform(5).build();

        FileUtils.write(imageTransformFile, imgTransformProcess.toJson());

        FileUtils.write(csvTransformFile, transformProcess.toJson());

        Unirest.setObjectMapper(new ObjectMapper() {
            private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
                    new org.nd4j.shade.jackson.databind.ObjectMapper();

            public <T> T readValue(String value, Class<T> valueType) {
                try {
                    return jacksonObjectMapper.readValue(value, valueType);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }

            public String writeValue(Object value) {
                try {
                    return jacksonObjectMapper.writeValueAsString(value);
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
        });

    }

    @AfterClass
    public static void after() throws Exception {
        imageTransformFile.deleteOnExit();
        csvTransformFile.deleteOnExit();
    }

    @Test
    public void testImageServer() throws Exception {
        serverChooser.runMain(new String[] {"--jsonPath", imageTransformFile.getAbsolutePath(), "-dp", "9060", "-dt",
                        TransformDataType.IMAGE.toString()});

        SingleImageRecord record =
                new SingleImageRecord(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI());
        JsonNode jsonNode = Unirest.post("http://localhost:9060/transformincrementalarray")
                .header("accept", "application/json").header("Content-Type", "application/json").body(record)
                .asJson().getBody();
        Base64NDArrayBody array = Unirest.post("http://localhost:9060/transformincrementalarray")
                .header("accept", "application/json").header("Content-Type", "application/json").body(record)
                .asObject(Base64NDArrayBody.class).getBody();

        BatchImageRecord batch = new BatchImageRecord();
        batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI());
        batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile().toURI());
        batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile().toURI());

        JsonNode jsonNodeBatch =
                Unirest.post("http://localhost:9060/transformarray").header("accept", "application/json")
                        .header("Content-Type", "application/json").body(batch).asJson().getBody();
        Base64NDArrayBody batchArray = Unirest.post("http://localhost:9060/transformarray")
                .header("accept", "application/json").header("Content-Type", "application/json").body(batch)
                .asObject(Base64NDArrayBody.class).getBody();

        INDArray result = getNDArray(jsonNode);
        assertEquals(1, result.size(0));

        INDArray batchResult = getNDArray(jsonNodeBatch);
        assertEquals(3, batchResult.size(0));

        serverChooser.getSparkTransformServer().stop();
    }

    @Test
    public void testCSVServer() throws Exception {
        serverChooser.runMain(new String[] {"--jsonPath", csvTransformFile.getAbsolutePath(), "-dp", "9050", "-dt",
                        TransformDataType.CSV.toString()});

        String[] values = new String[] {"1.0", "2.0"};
        SingleCSVRecord record = new SingleCSVRecord(values);
        JsonNode jsonNode =
                Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json")
                        .header("Content-Type", "application/json").body(record).asJson().getBody();
        SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental")
                .header("accept", "application/json").header("Content-Type", "application/json").body(record)
                .asObject(SingleCSVRecord.class).getBody();

        BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
        for (int i = 0; i < 3; i++)
            batchCSVRecord.add(singleCsvRecord);
        BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform")
                .header("accept", "application/json").header("Content-Type", "application/json")
                .body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody();

        Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray")
                .header("accept", "application/json").header("Content-Type", "application/json").body(record)
                .asObject(Base64NDArrayBody.class).getBody();

        Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray")
                .header("accept", "application/json").header("Content-Type", "application/json")
                .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody();

        serverChooser.getSparkTransformServer().stop();
    }

    public INDArray getNDArray(JsonNode node) throws IOException {
        return Nd4jBase64.fromBase64(node.getObject().getString("ndarray"));
    }
}
@ -1,6 +0,0 @@
play.modules.enabled += com.lightbend.lagom.discovery.zookeeper.ZooKeeperServiceLocatorModule
play.modules.enabled += io.skymind.skil.service.PredictionModule
play.crypto.secret = as8dufasdfuasdfjkasdkfalksjfk
play.server.pidfile.path=/tmp/RUNNING_PID

play.server.http.port = 9600
@ -1,68 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>datavec-spark-inference-parent</artifactId>
    <packaging>pom</packaging>

    <name>datavec-spark-inference-parent</name>

    <modules>
        <module>datavec-spark-inference-server</module>
        <module>datavec-spark-inference-client</module>
        <module>datavec-spark-inference-model</module>
    </modules>

    <dependencyManagement>
        <dependencies>
            <dependency>
                <groupId>org.datavec</groupId>
                <artifactId>datavec-data-image</artifactId>
                <version>${datavec.version}</version>
            </dependency>
            <dependency>
                <groupId>com.mashape.unirest</groupId>
                <artifactId>unirest-java</artifactId>
                <version>${unirest.version}</version>
            </dependency>
        </dependencies>
    </dependencyManagement>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
        </profile>
    </profiles>
</project>
@ -45,7 +45,6 @@
        <module>datavec-data</module>
        <module>datavec-spark</module>
        <module>datavec-local</module>
        <module>datavec-spark-inference-parent</module>
        <module>datavec-jdbc</module>
        <module>datavec-excel</module>
        <module>datavec-arrow</module>
@ -1,143 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ /* ******************************************************************************
~ *
~ *
~ * This program and the accompanying materials are made available under the
~ * terms of the Apache License, Version 2.0 which is available at
~ * https://www.apache.org/licenses/LICENSE-2.0.
~ *
~ * See the NOTICE file distributed with this work for additional
~ * information regarding copyright ownership.
~ * Unless required by applicable law or agreed to in writing, software
~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ * License for the specific language governing permissions and limitations
~ * under the License.
~ *
~ * SPDX-License-Identifier: Apache-2.0
~ ******************************************************************************/
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-nearestneighbors-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>deeplearning4j-nearestneighbor-server</artifactId>
    <packaging>jar</packaging>

    <name>deeplearning4j-nearestneighbor-server</name>

    <properties>
        <java.compile.version>1.8</java.compile.version>
    </properties>

    <dependencies>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-nearestneighbors-model</artifactId>
            <version>${project.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-core</artifactId>
            <version>${project.version}</version>
        </dependency>
        <dependency>
            <groupId>io.vertx</groupId>
            <artifactId>vertx-core</artifactId>
            <version>${vertx.version}</version>
        </dependency>
        <dependency>
            <groupId>io.vertx</groupId>
            <artifactId>vertx-web</artifactId>
            <version>${vertx.version}</version>
        </dependency>
        <dependency>
            <groupId>com.mashape.unirest</groupId>
            <artifactId>unirest-java</artifactId>
            <version>${unirest.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-nearestneighbors-client</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>com.beust</groupId>
            <artifactId>jcommander</artifactId>
            <version>${jcommander.version}</version>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-common-tests</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-surefire-plugin</artifactId>
                <configuration>
                    <argLine>-Dfile.encoding=UTF-8 -Xmx8g</argLine>
                    <includes>
                        <!-- Default setting only runs tests that start/end with "Test" -->
                        <include>*.java</include>
                        <include>**/*.java</include>
                    </includes>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <configuration>
                    <source>${java.compile.version}</source>
                    <target>${java.compile.version}</target>
                </configuration>
            </plugin>
        </plugins>
    </build>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
            <dependencies>
                <dependency>
                    <groupId>org.nd4j</groupId>
                    <artifactId>nd4j-native</artifactId>
                    <version>${project.version}</version>
                    <scope>test</scope>
                </dependency>
            </dependencies>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
            <dependencies>
                <dependency>
                    <groupId>org.nd4j</groupId>
                    <artifactId>nd4j-cuda-11.0</artifactId>
                    <version>${project.version}</version>
                    <scope>test</scope>
                </dependency>
            </dependencies>
        </profile>
    </profiles>
</project>
@ -1,67 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nearestneighbor.server;

import lombok.AllArgsConstructor;
import lombok.Builder;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborRequest;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.util.ArrayList;
import java.util.List;

@AllArgsConstructor
@Builder
public class NearestNeighbor {
    private NearestNeighborRequest record;
    private VPTree tree;
    private INDArray points;

    public List<NearestNeighborsResult> search() {
        INDArray input = points.slice(record.getInputIndex());
        List<NearestNeighborsResult> results = new ArrayList<>();
        if (input.isVector()) {
            List<DataPoint> add = new ArrayList<>();
            List<Double> distances = new ArrayList<>();
            tree.search(input, record.getK(), add, distances);

            if (add.size() != distances.size()) {
                throw new IllegalStateException(
                        String.format("add.size == %d != %d == distances.size",
                                add.size(), distances.size()));
            }

            for (int i = 0; i < add.size(); i++) {
                results.add(new NearestNeighborsResult(add.get(i).getIndex(), distances.get(i)));
            }
        }

        return results;
    }

}
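
The search() method above is a thin wrapper around VPTree: it slices the query row out of the points matrix and pairs the returned indices with their distances. A usage sketch with random illustrative data, assuming the Lombok-generated accessors on the model classes (setInputIndex/setK on NearestNeighborRequest, getIndex/getDistance on NearestNeighborsResult) and the same VPTree constructor used by the server below:

import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborRequest;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.List;

public class NearestNeighborExample {
    public static void main(String[] args) {
        INDArray points = Nd4j.rand(100, 3);                  // 100 random 3-d points
        VPTree tree = new VPTree(points, "euclidean", false); // as in NearestNeighborsServer

        NearestNeighborRequest request = new NearestNeighborRequest();
        request.setInputIndex(0); // query with the first row of the matrix
        request.setK(5);          // ask for the 5 nearest neighbours

        List<NearestNeighborsResult> results = NearestNeighbor.builder()
                .points(points).tree(tree).record(request).build().search();
        for (NearestNeighborsResult r : results)
            System.out.println(r.getIndex() + " -> " + r.getDistance());
    }
}
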
@ -1,278 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nearestneighbor.server;

import com.beust.jcommander.JCommander;
import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParameterException;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.vertx.core.AbstractVerticle;
import io.vertx.core.Vertx;
import io.vertx.ext.web.Router;
import io.vertx.ext.web.handler.BodyHandler;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.clustering.vptree.VPTreeFillSearch;
import org.deeplearning4j.exception.DL4JInvalidInputException;
import org.deeplearning4j.nearestneighbor.model.*;
import org.deeplearning4j.nn.conf.serde.JsonMappers;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.serde.base64.Nd4jBase64;
import org.nd4j.serde.binary.BinarySerde;

import java.io.File;
import java.util.*;

@Slf4j
public class NearestNeighborsServer extends AbstractVerticle {

    private static class RunArgs {
        @Parameter(names = {"--ndarrayPath"}, arity = 1, required = true)
        private String ndarrayPath = null;
        @Parameter(names = {"--labelsPath"}, arity = 1, required = false)
        private String labelsPath = null;
        @Parameter(names = {"--nearestNeighborsPort"}, arity = 1)
        private int port = 9000;
        @Parameter(names = {"--similarityFunction"}, arity = 1)
        private String similarityFunction = "euclidean";
        @Parameter(names = {"--invert"}, arity = 1)
        private boolean invert = false;
    }

    private static RunArgs instanceArgs;
    private static NearestNeighborsServer instance;

    public NearestNeighborsServer() { }

    public static NearestNeighborsServer getInstance() {
        return instance;
    }

    public static void runMain(String... args) {
        RunArgs r = new RunArgs();
        JCommander jcmdr = new JCommander(r);

        try {
            jcmdr.parse(args);
        } catch (ParameterException e) {
            log.error("Error in NearestNeighboursServer parameters", e);
            StringBuilder sb = new StringBuilder();
            jcmdr.usage(sb);
            log.error("Usage: {}", sb.toString());

            //User provided invalid input -> print the usage info
            jcmdr.usage();
            if (r.ndarrayPath == null)
                log.error("Json path parameter is missing (null)");
            try {
                Thread.sleep(500);
            } catch (Exception e2) {
            }
            System.exit(1);
        }

        instanceArgs = r;
        try {
            Vertx vertx = Vertx.vertx();
            vertx.deployVerticle(NearestNeighborsServer.class.getName());
        } catch (Throwable t) {
            log.error("Error in NearestNeighboursServer run method", t);
        }
    }

    @Override
    public void start() throws Exception {
        instance = this;

        String[] pathArr = instanceArgs.ndarrayPath.split(",");
        //INDArray[] pointsArr = new INDArray[pathArr.length];
        // First, read the shapes of the previously saved chunks
        int rows = 0;
        int cols = 0;
        for (int i = 0; i < pathArr.length; i++) {
            DataBuffer shape = BinarySerde.readShapeFromDisk(new File(pathArr[i]));

            log.info("Loading shape {} of {}; Shape: [{} x {}]", i + 1, pathArr.length, Shape.size(shape, 0),
                    Shape.size(shape, 1));

            if (Shape.rank(shape) != 2)
                throw new DL4JInvalidInputException("NearestNeighborsServer assumes 2D chunks");

            rows += Shape.size(shape, 0);

            if (cols == 0)
                cols = Shape.size(shape, 1);
            else if (cols != Shape.size(shape, 1))
                throw new DL4JInvalidInputException(
                        "NearestNeighborsServer requires equal 2D chunks. Got columns mismatch.");
        }

        final List<String> labels = new ArrayList<>();
        if (instanceArgs.labelsPath != null) {
            String[] labelsPathArr = instanceArgs.labelsPath.split(",");
            for (int i = 0; i < labelsPathArr.length; i++) {
                labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8"));
            }
        }
        if (!labels.isEmpty() && labels.size() != rows)
            throw new DL4JInvalidInputException(String.format("Number of labels must match number of rows in points matrix (expected %d, found %d)", rows, labels.size()));

        final INDArray points = Nd4j.createUninitialized(rows, cols);

        int lastPosition = 0;
        for (int i = 0; i < pathArr.length; i++) {
            log.info("Loading chunk {} of {}", i + 1, pathArr.length);
            INDArray pointsArr = BinarySerde.readFromDisk(new File(pathArr[i]));

            points.get(NDArrayIndex.interval(lastPosition, lastPosition + pointsArr.rows())).assign(pointsArr);
            lastPosition += pointsArr.rows();

            // Encourage GC between chunks so we don't hold two copies in memory at once
            System.gc();
        }

        VPTree tree = new VPTree(points, instanceArgs.similarityFunction, instanceArgs.invert);

        //Set play secret key, if required
        //http://www.playframework.com/documentation/latest/ApplicationSecret
        String crypto = System.getProperty("play.crypto.secret");
        if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) {
            byte[] newCrypto = new byte[1024];

            new Random().nextBytes(newCrypto);

            String base64 = Base64.getEncoder().encodeToString(newCrypto);
            System.setProperty("play.crypto.secret", base64);
        }

        Router r = Router.router(vertx);
        r.route().handler(BodyHandler.create()); //NOTE: Setting this is required to receive request body content at all
        createRoutes(r, labels, tree, points);

        vertx.createHttpServer()
                .requestHandler(r)
                .listen(instanceArgs.port);
    }

    private void createRoutes(Router r, List<String> labels, VPTree tree, INDArray points) {

        r.post("/knn").handler(rc -> {
            try {
                String json = rc.getBodyAsJson().encode();
                NearestNeighborRequest record = JsonMappers.getMapper().readValue(json, NearestNeighborRequest.class);
|
||||
|
||||
NearestNeighbor nearestNeighbor =
|
||||
NearestNeighbor.builder().points(points).record(record).tree(tree).build();
|
||||
|
||||
if (record == null) {
|
||||
rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
|
||||
.putHeader("content-type", "application/json")
|
||||
.end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
|
||||
return;
|
||||
}
|
||||
|
||||
NearestNeighborsResults results = NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
|
||||
|
||||
rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
|
||||
.putHeader("content-type", "application/json")
|
||||
.end(JsonMappers.getMapper().writeValueAsString(results));
|
||||
return;
|
||||
} catch (Throwable e) {
|
||||
log.error("Error in POST /knn",e);
|
||||
rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
|
||||
.end("Error parsing request - " + e.getMessage());
|
||||
return;
|
||||
}
|
||||
});
|
||||
|
||||
r.post("/knnnew").handler(rc -> {
|
||||
try {
|
||||
String json = rc.getBodyAsJson().encode();
|
||||
Base64NDArrayBody record = JsonMappers.getMapper().readValue(json, Base64NDArrayBody.class);
|
||||
if (record == null) {
|
||||
rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
|
||||
.putHeader("content-type", "application/json")
|
||||
.end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
|
||||
return;
|
||||
}
|
||||
|
||||
INDArray arr = Nd4jBase64.fromBase64(record.getNdarray());
|
||||
List<DataPoint> results;
|
||||
List<Double> distances;
|
||||
|
||||
if (record.isForceFillK()) {
|
||||
VPTreeFillSearch vpTreeFillSearch = new VPTreeFillSearch(tree, record.getK(), arr);
|
||||
vpTreeFillSearch.search();
|
||||
results = vpTreeFillSearch.getResults();
|
||||
distances = vpTreeFillSearch.getDistances();
|
||||
} else {
|
||||
results = new ArrayList<>();
|
||||
distances = new ArrayList<>();
|
||||
tree.search(arr, record.getK(), results, distances);
|
||||
}
|
||||
|
||||
if (results.size() != distances.size()) {
|
||||
rc.response()
|
||||
.setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
|
||||
.end(String.format("results.size == %d != %d == distances.size", results.size(), distances.size()));
|
||||
return;
|
||||
}
|
||||
|
||||
List<NearestNeighborsResult> nnResult = new ArrayList<>();
|
||||
for (int i=0; i<results.size(); i++) {
|
||||
if (!labels.isEmpty())
|
||||
nnResult.add(new NearestNeighborsResult(results.get(i).getIndex(), distances.get(i), labels.get(results.get(i).getIndex())));
|
||||
else
|
||||
nnResult.add(new NearestNeighborsResult(results.get(i).getIndex(), distances.get(i)));
|
||||
}
|
||||
|
||||
NearestNeighborsResults results2 = NearestNeighborsResults.builder().results(nnResult).build();
|
||||
String j = JsonMappers.getMapper().writeValueAsString(results2);
|
||||
rc.response()
|
||||
.putHeader("content-type", "application/json")
|
||||
.end(j);
|
||||
} catch (Throwable e) {
|
||||
log.error("Error in POST /knnnew",e);
|
||||
rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
|
||||
.end("Error parsing request - " + e.getMessage());
|
||||
return;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop the server
|
||||
*/
|
||||
public void stop() throws Exception {
|
||||
super.stop();
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
runMain(args);
|
||||
}
|
||||
|
||||
}
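
Note: the server was typically launched programmatically via runMain, as the testServer test below demonstrates. A minimal sketch (the points file path is a placeholder; the file is expected to contain a 2D array persisted with BinarySerde.writeArrayToDisk):

    // Serve k-NN queries on port 9000 over a previously persisted points matrix
    NearestNeighborsServer.runMain("--ndarrayPath", "/tmp/points.bin", "--nearestNeighborsPort", "9000");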
@ -1,161 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nearestneighbor.server;

import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.vptree.VPTree;
import org.deeplearning4j.clustering.vptree.VPTreeFillSearch;
import org.deeplearning4j.nearestneighbor.client.NearestNeighborsClient;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborRequest;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResults;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.serde.binary.BinarySerde;

import java.io.File;
import java.io.IOException;
import java.net.ServerSocket;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;

import static org.junit.Assert.assertEquals;

public class NearestNeighborTest extends BaseDL4JTest {

    @Rule
    public TemporaryFolder testDir = new TemporaryFolder();

    @Test
    public void testNearestNeighbor() {
        double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}};
        INDArray arr = Nd4j.create(data);

        VPTree vpTree = new VPTree(arr, false);
        NearestNeighborRequest request = new NearestNeighborRequest();
        request.setK(2);
        request.setInputIndex(0);
        NearestNeighbor nearestNeighbor = NearestNeighbor.builder().tree(vpTree).points(arr).record(request).build();
        List<NearestNeighborsResult> results = nearestNeighbor.search();
        assertEquals(1, results.get(0).getIndex());
        assertEquals(2, results.size());

        assertEquals(1.0, results.get(0).getDistance(), 1e-4);
        assertEquals(4.0, results.get(1).getDistance(), 1e-4);
    }

    @Test
    public void testNearestNeighborInverted() {
        double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}};
        INDArray arr = Nd4j.create(data);

        VPTree vpTree = new VPTree(arr, true);
        NearestNeighborRequest request = new NearestNeighborRequest();
        request.setK(2);
        request.setInputIndex(0);
        NearestNeighbor nearestNeighbor = NearestNeighbor.builder().tree(vpTree).points(arr).record(request).build();
        List<NearestNeighborsResult> results = nearestNeighbor.search();
        assertEquals(2, results.get(0).getIndex());
        assertEquals(2, results.size());

        assertEquals(-4.0, results.get(0).getDistance(), 1e-4);
        assertEquals(-1.0, results.get(1).getDistance(), 1e-4);
    }

    @Test
    public void vpTreeTest() throws Exception {
        INDArray matrix = Nd4j.rand(new int[] {400, 10});
        INDArray rowVector = matrix.getRow(70);
        INDArray resultArr = Nd4j.zeros(400, 1);
        Executor executor = Executors.newSingleThreadExecutor();
        VPTree vpTree = new VPTree(matrix);
        System.out.println("Ran!");
    }

    public static int getAvailablePort() {
        try {
            ServerSocket socket = new ServerSocket(0);
            try {
                return socket.getLocalPort();
            } finally {
                socket.close();
            }
        } catch (IOException e) {
            throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e);
        }
    }

    @Test
    public void testServer() throws Exception {
        int localPort = getAvailablePort();
        Nd4j.getRandom().setSeed(7);
        INDArray rand = Nd4j.randn(10, 5);
        File writeToTmp = testDir.newFile();
        writeToTmp.deleteOnExit();
        BinarySerde.writeArrayToDisk(rand, writeToTmp);
        NearestNeighborsServer.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort",
                String.valueOf(localPort));

        Thread.sleep(3000);

        NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort);
        NearestNeighborsResults result = client.knnNew(5, rand.getRow(0));
        assertEquals(5, result.getResults().size());
        NearestNeighborsServer.getInstance().stop();
    }

    @Test
    public void testFullSearch() throws Exception {
        int numRows = 1000;
        int numCols = 100;
        int numNeighbors = 42;
        INDArray points = Nd4j.rand(numRows, numCols);
        VPTree tree = new VPTree(points);
        INDArray query = Nd4j.rand(new int[] {1, numCols});
        VPTreeFillSearch fillSearch = new VPTreeFillSearch(tree, numNeighbors, query);
        fillSearch.search();
        List<DataPoint> results = fillSearch.getResults();
        List<Double> distances = fillSearch.getDistances();
        assertEquals(numNeighbors, distances.size());
        assertEquals(numNeighbors, results.size());
    }

    @Test
    public void testDistances() {
        INDArray indArray = Nd4j.create(new float[][] {{3, 4}, {1, 2}, {5, 6}});
        INDArray record = Nd4j.create(new float[][] {{7, 6}});
        VPTree vpTree = new VPTree(indArray, "euclidean", false);
        VPTreeFillSearch vpTreeFillSearch = new VPTreeFillSearch(vpTree, 3, record);
        vpTreeFillSearch.search();
        //System.out.println(vpTreeFillSearch.getResults());
        System.out.println(vpTreeFillSearch.getDistances());
    }
}
@ -1,46 +0,0 @@
<!--
  ~ /* ******************************************************************************
  ~ *
  ~ *
  ~ * This program and the accompanying materials are made available under the
  ~ * terms of the Apache License, Version 2.0 which is available at
  ~ * https://www.apache.org/licenses/LICENSE-2.0.
  ~ *
  ~ * See the NOTICE file distributed with this work for additional
  ~ * information regarding copyright ownership.
  ~ * Unless required by applicable law or agreed to in writing, software
  ~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~ * License for the specific language governing permissions and limitations
  ~ * under the License.
  ~ *
  ~ * SPDX-License-Identifier: Apache-2.0
  ~ ******************************************************************************/
-->

<configuration>

    <appender name="FILE" class="ch.qos.logback.core.FileAppender">
        <file>logs/application.log</file>
        <encoder>
            <pattern> %logger{15} - %message%n%xException{5}
            </pattern>
        </encoder>
    </appender>

    <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
        <encoder>
            <pattern> %logger{15} - %message%n%xException{5}
            </pattern>
        </encoder>
    </appender>

    <logger name="org.deeplearning4j" level="INFO" />
    <logger name="org.datavec" level="INFO" />
    <logger name="org.nd4j" level="INFO" />

    <root level="ERROR">
        <appender-ref ref="STDOUT" />
        <appender-ref ref="FILE" />
    </root>
</configuration>
@ -1,60 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  ~ /* ******************************************************************************
  ~ *
  ~ *
  ~ * This program and the accompanying materials are made available under the
  ~ * terms of the Apache License, Version 2.0 which is available at
  ~ * https://www.apache.org/licenses/LICENSE-2.0.
  ~ *
  ~ * See the NOTICE file distributed with this work for additional
  ~ * information regarding copyright ownership.
  ~ * Unless required by applicable law or agreed to in writing, software
  ~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~ * License for the specific language governing permissions and limitations
  ~ * under the License.
  ~ *
  ~ * SPDX-License-Identifier: Apache-2.0
  ~ ******************************************************************************/
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-nearestneighbors-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>deeplearning4j-nearestneighbors-client</artifactId>
    <packaging>jar</packaging>

    <name>deeplearning4j-nearestneighbors-client</name>

    <dependencies>
        <dependency>
            <groupId>com.mashape.unirest</groupId>
            <artifactId>unirest-java</artifactId>
            <version>${unirest.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-nearestneighbors-model</artifactId>
            <version>${project.version}</version>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
        </profile>
    </profiles>
</project>
@ -1,137 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nearestneighbor.client;

import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import com.mashape.unirest.request.HttpRequest;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.Setter;
import lombok.val;
import org.deeplearning4j.nearestneighbor.model.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;
import org.nd4j.shade.jackson.core.JsonProcessingException;

import java.io.IOException;

@AllArgsConstructor
public class NearestNeighborsClient {

    private String url;
    @Setter
    @Getter
    protected String authToken;

    public NearestNeighborsClient(String url) {
        this(url, null);
    }

    static {
        //Configure the shared Unirest object mapper exactly once
        Unirest.setObjectMapper(new ObjectMapper() {
            private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
                    new org.nd4j.shade.jackson.databind.ObjectMapper();

            public <T> T readValue(String value, Class<T> valueType) {
                try {
                    return jacksonObjectMapper.readValue(value, valueType);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }

            public String writeValue(Object value) {
                try {
                    return jacksonObjectMapper.writeValueAsString(value);
                } catch (JsonProcessingException e) {
                    throw new RuntimeException(e);
                }
            }
        });
    }

    /**
     * Runs knn on the given index with the given k. Note that this is
     * for data already within the existing dataset, not new data.
     *
     * @param index the index of the EXISTING ndarray to run a search on
     * @param k     the number of results
     * @return the nearest neighbors of the given index
     * @throws Exception if the HTTP request fails
     */
    public NearestNeighborsResults knn(int index, int k) throws Exception {
        NearestNeighborRequest request = new NearestNeighborRequest();
        request.setInputIndex(index);
        request.setK(k);
        val req = Unirest.post(url + "/knn");
        req.header("accept", "application/json")
                .header("Content-Type", "application/json").body(request);
        addAuthHeader(req);

        NearestNeighborsResults ret = req.asObject(NearestNeighborsResults.class).getBody();
        return ret;
    }

    /**
     * Run a k nearest neighbors search on a NEW data point.
     *
     * @param k   the number of results to retrieve
     * @param arr the array to run the search on. Note that this must be a row vector
     * @return the nearest neighbors of the given array
     * @throws Exception if the HTTP request fails
     */
    public NearestNeighborsResults knnNew(int k, INDArray arr) throws Exception {
        Base64NDArrayBody base64NDArrayBody =
                Base64NDArrayBody.builder().k(k).ndarray(Nd4jBase64.base64String(arr)).build();

        val req = Unirest.post(url + "/knnnew");
        req.header("accept", "application/json")
                .header("Content-Type", "application/json").body(base64NDArrayBody);
        addAuthHeader(req);

        NearestNeighborsResults ret = req.asObject(NearestNeighborsResults.class).getBody();

        return ret;
    }

    /**
     * Add the specified authentication header to the specified HttpRequest
     *
     * @param request HTTP Request to add the authentication header to
     */
    protected HttpRequest addAuthHeader(HttpRequest request) {
        if (authToken != null) {
            request.header("authorization", "Bearer " + authToken);
        }

        return request;
    }

}
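
A minimal usage sketch against a running server (the URL is a placeholder; knnNew expects a row vector, per the javadoc above):

    NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:9000");
    NearestNeighborsResults fresh = client.knnNew(5, queryRow);   // queryRow: a hypothetical 1 x N INDArray
    NearestNeighborsResults known = client.knn(0, 5);             // 5 neighbors of row 0 of the server's dataset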
@ -1,61 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  ~ /* ******************************************************************************
  ~ *
  ~ *
  ~ * This program and the accompanying materials are made available under the
  ~ * terms of the Apache License, Version 2.0 which is available at
  ~ * https://www.apache.org/licenses/LICENSE-2.0.
  ~ *
  ~ * See the NOTICE file distributed with this work for additional
  ~ * information regarding copyright ownership.
  ~ * Unless required by applicable law or agreed to in writing, software
  ~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~ * License for the specific language governing permissions and limitations
  ~ * under the License.
  ~ *
  ~ * SPDX-License-Identifier: Apache-2.0
  ~ ******************************************************************************/
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-nearestneighbors-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>deeplearning4j-nearestneighbors-model</artifactId>
    <packaging>jar</packaging>

    <name>deeplearning4j-nearestneighbors-model</name>

    <dependencies>
        <dependency>
            <groupId>org.projectlombok</groupId>
            <artifactId>lombok</artifactId>
            <version>${lombok.version}</version>
            <scope>provided</scope>
        </dependency>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-api</artifactId>
            <version>${nd4j.version}</version>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
        </profile>
    </profiles>
</project>
@ -1,38 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nearestneighbor.model;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.io.Serializable;

@Data
@AllArgsConstructor
@NoArgsConstructor
@Builder
public class Base64NDArrayBody implements Serializable {
    private String ndarray;
    private int k;
    private boolean forceFillK;
}
@ -1,65 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nearestneighbor.model;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.dataset.DataSet;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

@Data
@AllArgsConstructor
@Builder
@NoArgsConstructor
public class BatchRecord implements Serializable {
    private List<CSVRecord> records;

    /**
     * Add a record
     * @param record the record to add
     */
    public void add(CSVRecord record) {
        if (records == null)
            records = new ArrayList<>();
        records.add(record);
    }


    /**
     * Return a batch record based on a dataset
     * @param dataSet the dataset to get the batch record for
     * @return the batch record
     */
    public static BatchRecord fromDataSet(DataSet dataSet) {
        BatchRecord batchRecord = new BatchRecord();
        for (int i = 0; i < dataSet.numExamples(); i++) {
            batchRecord.add(CSVRecord.fromRow(dataSet.get(i)));
        }

        return batchRecord;
    }

}
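
A minimal sketch of fromDataSet (dataSet is an assumed variable holding vector rows of features and labels, as required by CSVRecord.fromRow below):

    // Each example in the DataSet becomes one CSVRecord in the batch
    BatchRecord batch = BatchRecord.fromDataSet(dataSet);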
@ -1,85 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nearestneighbor.model;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.dataset.DataSet;

import java.io.Serializable;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class CSVRecord implements Serializable {
    private String[] values;

    /**
     * Instantiate a CSV record from a dataset row.
     * For classification (one-hot labels), the index of the maximum label
     * value is appended to the end of the record; for regression, all
     * values in the labels are appended.
     * @param row the input dataset row (features and labels)
     * @return the record from this {@link DataSet}
     */
    public static CSVRecord fromRow(DataSet row) {
        if (!row.getFeatures().isVector() && !row.getFeatures().isScalar())
            throw new IllegalArgumentException("Passed in dataset must represent a scalar or vector");
        if (!row.getLabels().isVector() && !row.getLabels().isScalar())
            throw new IllegalArgumentException("Passed in dataset labels must be a scalar or vector");
        //classification
        CSVRecord record;
        int idx = 0;
        if (row.getLabels().sumNumber().doubleValue() == 1.0) {
            String[] values = new String[row.getFeatures().columns() + 1];
            for (int i = 0; i < row.getFeatures().length(); i++) {
                values[idx++] = String.valueOf(row.getFeatures().getDouble(i));
            }
            int maxIdx = 0;
            for (int i = 0; i < row.getLabels().length(); i++) {
                if (row.getLabels().getDouble(maxIdx) < row.getLabels().getDouble(i)) {
                    maxIdx = i;
                }
            }

            values[idx++] = String.valueOf(maxIdx);
            record = new CSVRecord(values);
        }
        //regression (any number of values)
        else {
            String[] values = new String[row.getFeatures().columns() + row.getLabels().columns()];
            for (int i = 0; i < row.getFeatures().length(); i++) {
                values[idx++] = String.valueOf(row.getFeatures().getDouble(i));
            }
            for (int i = 0; i < row.getLabels().length(); i++) {
                values[idx++] = String.valueOf(row.getLabels().getDouble(i));
            }

            record = new CSVRecord(values);
        }
        return record;
    }

}
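
A worked example of the classification branch (the numbers are invented for illustration): features [0.1, 0.9] with one-hot labels [0.0, 1.0] sum to 1.0, so the argmax index 1 is appended after the feature values:

    DataSet ds = new DataSet(Nd4j.create(new double[] {0.1, 0.9}, new long[] {1, 2}),
                    Nd4j.create(new double[] {0.0, 1.0}, new long[] {1, 2}));
    CSVRecord rec = CSVRecord.fromRow(ds);   // rec.getValues() -> {"0.1", "0.9", "1"}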
@ -1,32 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nearestneighbor.model;

import lombok.Data;

import java.io.Serializable;

@Data
public class NearestNeighborRequest implements Serializable {
    private int k;
    private int inputIndex;

}
@ -1,37 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nearestneighbor.model;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class NearestNeighborsResult {

    private int index;
    private double distance;
    private String label;

    public NearestNeighborsResult(int index, double distance) {
        this(index, distance, null);
    }
}
@ -1,38 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.nearestneighbor.model;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.io.Serializable;
import java.util.List;

@Data
@Builder
@NoArgsConstructor
@AllArgsConstructor
public class NearestNeighborsResults implements Serializable {
    private List<NearestNeighborsResult> results;

}
@ -1,103 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  ~ /* ******************************************************************************
  ~ *
  ~ *
  ~ * This program and the accompanying materials are made available under the
  ~ * terms of the Apache License, Version 2.0 which is available at
  ~ * https://www.apache.org/licenses/LICENSE-2.0.
  ~ *
  ~ * See the NOTICE file distributed with this work for additional
  ~ * information regarding copyright ownership.
  ~ * Unless required by applicable law or agreed to in writing, software
  ~ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~ * License for the specific language governing permissions and limitations
  ~ * under the License.
  ~ *
  ~ * SPDX-License-Identifier: Apache-2.0
  ~ ******************************************************************************/
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-nearestneighbors-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>nearestneighbor-core</artifactId>
    <packaging>jar</packaging>

    <name>nearestneighbor-core</name>

    <dependencies>
        <dependency>
            <groupId>org.nd4j</groupId>
            <artifactId>nd4j-api</artifactId>
            <version>${nd4j.version}</version>
        </dependency>
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
        </dependency>
        <dependency>
            <groupId>ch.qos.logback</groupId>
            <artifactId>logback-classic</artifactId>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-nn</artifactId>
            <version>${project.version}</version>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-datasets</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>joda-time</groupId>
            <artifactId>joda-time</artifactId>
            <version>2.10.3</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.deeplearning4j</groupId>
            <artifactId>deeplearning4j-common-tests</artifactId>
            <version>${project.version}</version>
            <scope>test</scope>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
            <dependencies>
                <dependency>
                    <groupId>org.nd4j</groupId>
                    <artifactId>nd4j-native</artifactId>
                    <version>${project.version}</version>
                    <scope>test</scope>
                </dependency>
            </dependencies>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
            <dependencies>
                <dependency>
                    <groupId>org.nd4j</groupId>
                    <artifactId>nd4j-cuda-11.0</artifactId>
                    <version>${project.version}</version>
                    <scope>test</scope>
                </dependency>
            </dependencies>
        </profile>
    </profiles>
</project>
@ -1,218 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.algorithm;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.clustering.cluster.Cluster;
import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.ClusterUtils;
import org.deeplearning4j.clustering.cluster.Point;
import org.deeplearning4j.clustering.info.ClusterSetInfo;
import org.deeplearning4j.clustering.iteration.IterationHistory;
import org.deeplearning4j.clustering.iteration.IterationInfo;
import org.deeplearning4j.clustering.strategy.ClusteringStrategy;
import org.deeplearning4j.clustering.strategy.ClusteringStrategyType;
import org.deeplearning4j.clustering.strategy.OptimisationStrategy;
import org.deeplearning4j.clustering.util.MultiThreadUtils;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;

@Slf4j
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializable {

    private static final long serialVersionUID = 338231277453149972L;

    private ClusteringStrategy clusteringStrategy;
    private IterationHistory iterationHistory;
    private int currentIteration = 0;
    private ClusterSet clusterSet;
    private List<Point> initialPoints;
    private transient ExecutorService exec;
    private boolean useKmeansPlusPlus;

    protected BaseClusteringAlgorithm(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) {
        this.clusteringStrategy = clusteringStrategy;
        this.exec = MultiThreadUtils.newExecutorService();
        this.useKmeansPlusPlus = useKmeansPlusPlus;
    }

    /**
     * Set up a clustering algorithm for the given strategy
     *
     * @param clusteringStrategy the strategy (cluster count, distance function, termination condition) to use
     * @param useKmeansPlusPlus  whether to use k-means++ style seeding of the initial cluster centers
     * @return the configured clustering algorithm
     */
    public static BaseClusteringAlgorithm setup(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) {
        return new BaseClusteringAlgorithm(clusteringStrategy, useKmeansPlusPlus);
    }

    /**
     * Run the clustering algorithm on the given points
     *
     * @param points the points to cluster
     * @return the resulting cluster set
     */
    @Override
    public ClusterSet applyTo(List<Point> points) {
        resetState(points);
        initClusters(useKmeansPlusPlus);
        iterations();
        return clusterSet;
    }

    private void resetState(List<Point> points) {
        this.iterationHistory = new IterationHistory();
        this.currentIteration = 0;
        this.clusterSet = null;
        this.initialPoints = points;
    }

    /** Run clustering iterations until a termination condition is hit.
     * This is done by first classifying all points, and then updating
     * cluster centers based on those classified points
     */
    private void iterations() {
        int iterationCount = 0;
        while ((clusteringStrategy.getTerminationCondition() != null
                        && !clusteringStrategy.getTerminationCondition().isSatisfied(iterationHistory))
                        || iterationHistory.getMostRecentIterationInfo().isStrategyApplied()) {
            currentIteration++;
            removePoints();
            classifyPoints();
            applyClusteringStrategy();
            log.trace("Completed clustering iteration {}", ++iterationCount);
        }
    }

    protected void classifyPoints() {
        //Classify points. This also adds each point to the ClusterSet
        ClusterSetInfo clusterSetInfo = ClusterUtils.classifyPoints(clusterSet, initialPoints, exec);
        //Update the cluster centers, based on the points within each cluster
        ClusterUtils.refreshClustersCenters(clusterSet, clusterSetInfo, exec);
        iterationHistory.getIterationsInfos().put(currentIteration,
                        new IterationInfo(currentIteration, clusterSetInfo));
    }

    /**
     * Initialize the cluster centers at random
     */
    protected void initClusters(boolean kMeansPlusPlus) {
        log.info("Generating initial clusters");
        List<Point> points = new ArrayList<>(initialPoints);

        //Initialize the ClusterSet with a single cluster center (based on position of one of the points chosen randomly)
        val random = Nd4j.getRandom();
        Distance distanceFn = clusteringStrategy.getDistanceFunction();
        int initialClusterCount = clusteringStrategy.getInitialClusterCount();
        clusterSet = new ClusterSet(distanceFn,
                        clusteringStrategy.inverseDistanceCalculation(), new long[]{initialClusterCount, points.get(0).getArray().length()});
        clusterSet.addNewClusterWithCenter(points.remove(random.nextInt(points.size())));

        //dxs: distances between each point and the nearest cluster to that point
        INDArray dxs = Nd4j.create(points.size());
        dxs.addi(clusteringStrategy.inverseDistanceCalculation() ? -Double.MAX_VALUE : Double.MAX_VALUE);

        //Generate the initial cluster centers, by randomly selecting a point between 0 and max distance
        //Thus, we are more likely to select (as a new cluster center) a point that is far from an existing cluster
        while (clusterSet.getClusterCount() < initialClusterCount && !points.isEmpty()) {
            dxs = ClusterUtils.computeSquareDistancesFromNearestCluster(clusterSet, points, dxs, exec);
            double summed = Nd4j.sum(dxs).getDouble(0);
            double r = kMeansPlusPlus ? random.nextDouble() * summed
                            : random.nextFloat() * dxs.maxNumber().doubleValue();

            for (int i = 0; i < dxs.length(); i++) {
                double distance = dxs.getDouble(i);
                Preconditions.checkState(distance >= 0, "Encountered negative distance: distance function is not valid? Distance " +
                                "function must return values >= 0, got distance %s for function %s", distance, distanceFn);
                if (dxs.getDouble(i) >= r) {
                    clusterSet.addNewClusterWithCenter(points.remove(i));
                    dxs = Nd4j.create(ArrayUtils.remove(dxs.data().asDouble(), i));
                    break;
                }
            }
        }

        ClusterSetInfo initialClusterSetInfo = ClusterUtils.computeClusterSetInfo(clusterSet);
        iterationHistory.getIterationsInfos().put(currentIteration,
                        new IterationInfo(currentIteration, initialClusterSetInfo));
    }

    protected void applyClusteringStrategy() {
        if (!isStrategyApplicableNow())
            return;

        ClusterSetInfo clusterSetInfo = iterationHistory.getMostRecentClusterSetInfo();
        if (!clusteringStrategy.isAllowEmptyClusters()) {
            int removedCount = removeEmptyClusters(clusterSetInfo);
            if (removedCount > 0) {
                iterationHistory.getMostRecentIterationInfo().setStrategyApplied(true);

                if (clusteringStrategy.isStrategyOfType(ClusteringStrategyType.FIXED_CLUSTER_COUNT)
                                && clusterSet.getClusterCount() < clusteringStrategy.getInitialClusterCount()) {
                    int splitCount = ClusterUtils.splitMostSpreadOutClusters(clusterSet, clusterSetInfo,
                                    clusteringStrategy.getInitialClusterCount() - clusterSet.getClusterCount(), exec);
                    if (splitCount > 0)
                        iterationHistory.getMostRecentIterationInfo().setStrategyApplied(true);
                }
            }
        }
        if (clusteringStrategy.isStrategyOfType(ClusteringStrategyType.OPTIMIZATION))
            optimize();
    }

    protected void optimize() {
        ClusterSetInfo clusterSetInfo = iterationHistory.getMostRecentClusterSetInfo();
        OptimisationStrategy optimization = (OptimisationStrategy) clusteringStrategy;
        boolean applied = ClusterUtils.applyOptimization(optimization, clusterSet, clusterSetInfo, exec);
        iterationHistory.getMostRecentIterationInfo().setStrategyApplied(applied);
    }

    private boolean isStrategyApplicableNow() {
        return clusteringStrategy.isOptimizationDefined() && iterationHistory.getIterationCount() != 0
                        && clusteringStrategy.isOptimizationApplicableNow(iterationHistory);
    }

    protected int removeEmptyClusters(ClusterSetInfo clusterSetInfo) {
        List<Cluster> removedClusters = clusterSet.removeEmptyClusters();
        clusterSetInfo.removeClusterInfos(removedClusters);
        return removedClusters.size();
    }

    protected void removePoints() {
        clusterSet.removePoints();
    }

}
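
A minimal usage sketch (construction of the strategy is elided; any ClusteringStrategy from org.deeplearning4j.clustering.strategy works):

    // useKmeansPlusPlus = true selects the k-means++ style seeding implemented in initClusters above
    BaseClusteringAlgorithm algorithm = BaseClusteringAlgorithm.setup(strategy, true);
    ClusterSet clusters = algorithm.applyTo(points);   // points: List<Point>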
@ -1,38 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.algorithm;

import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.Point;

import java.util.List;

public interface ClusteringAlgorithm {

    /**
     * Apply a clustering algorithm to the given set of points
     * @param points the points to cluster
     * @return the resulting cluster set
     */
    ClusterSet applyTo(List<Point> points);

}
@ -1,41 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.algorithm;

public enum Distance {
    EUCLIDEAN("euclidean"),
    COSINE_DISTANCE("cosinedistance"),
    COSINE_SIMILARITY("cosinesimilarity"),
    MANHATTAN("manhattan"),
    DOT("dot"),
    JACCARD("jaccard"),
    HAMMING("hamming");

    private String functionName;

    private Distance(String name) {
        functionName = name;
    }

    @Override
    public String toString() {
        return functionName;
    }
}
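
The toString() value is how these constants line up with the string-based APIs elsewhere in this diff (the server's --similarityFunction flag and the VPTree string constructors in the tests above both use the lower-case names):

    Distance.EUCLIDEAN.toString();           // "euclidean"
    Distance.COSINE_SIMILARITY.toString();   // "cosinesimilarity"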
@ -1,105 +0,0 @@
/*
 * ******************************************************************************
 * *
 * *
 * * This program and the accompanying materials are made available under the
 * * terms of the Apache License, Version 2.0 which is available at
 * * https://www.apache.org/licenses/LICENSE-2.0.
 * *
 * * See the NOTICE file distributed with this work for additional
 * * information regarding copyright ownership.
 * * Unless required by applicable law or agreed to in writing, software
 * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * * License for the specific language governing permissions and limitations
 * * under the License.
 * *
 * * SPDX-License-Identifier: Apache-2.0
 * *****************************************************************************
 */

package org.deeplearning4j.clustering.cluster;

import org.deeplearning4j.clustering.algorithm.Distance;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;

public class CentersHolder {
    private INDArray centers;
    private long index = 0;

    protected transient ReduceOp op;
    protected ArgMin imin;
    protected transient INDArray distances;
    protected transient INDArray argMin;

    private long rows, cols;

    public CentersHolder(long rows, long cols) {
        this.rows = rows;
        this.cols = cols;
    }

    public INDArray getCenters() {
        return this.centers;
    }

    public synchronized void addCenter(INDArray pointView) {
        if (centers == null)
            this.centers = Nd4j.create(pointView.dataType(), new long[] {rows, cols});

        centers.putRow(index++, pointView);
    }

    public synchronized Pair<Double, Long> getCenterByMinDistance(Point point, Distance distanceFunction) {
        if (distances == null)
            distances = Nd4j.create(centers.dataType(), centers.rows());

        if (argMin == null)
            argMin = Nd4j.createUninitialized(DataType.LONG, new long[0]);

        //Lazily build the distance reduce op and the argmin op once, then reuse them across calls
        if (op == null) {
            op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1);
            imin = new ArgMin(distances, argMin);
            op.setZ(distances);
        }

        op.setY(point.getArray());

        Nd4j.getExecutioner().exec(op);
        Nd4j.getExecutioner().exec(imin);

        Pair<Double, Long> result = new Pair<>();
        result.setFirst(distances.getDouble(argMin.getLong(0)));
        result.setSecond(argMin.getLong(0));
        return result;
    }

    public synchronized INDArray getMinDistances(Point point, Distance distanceFunction) {
        if (distances == null)
            distances = Nd4j.create(centers.dataType(), centers.rows());

        if (argMin == null)
            argMin = Nd4j.createUninitialized(DataType.LONG, new long[0]);

        if (op == null) {
            op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1);
            imin = new ArgMin(distances, argMin);
            op.setZ(distances);
        }

        op.setY(point.getArray());

        Nd4j.getExecutioner().exec(op);
        Nd4j.getExecutioner().exec(imin);

        return distances;
    }

}
@ -1,150 +0,0 @@
/*
|
||||
* ******************************************************************************
|
||||
* *
|
||||
* *
|
||||
* * This program and the accompanying materials are made available under the
|
||||
* * terms of the Apache License, Version 2.0 which is available at
|
||||
* * https://www.apache.org/licenses/LICENSE-2.0.
|
||||
* *
|
||||
* * See the NOTICE file distributed with this work for additional
|
||||
* * information regarding copyright ownership.
|
||||
* * Unless required by applicable law or agreed to in writing, software
|
||||
* * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
* * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
* * License for the specific language governing permissions and limitations
|
||||
* * under the License.
|
||||
* *
|
||||
* * SPDX-License-Identifier: Apache-2.0
|
||||
* *****************************************************************************
|
||||
*/
|
||||
|
||||
package org.deeplearning4j.clustering.cluster;

import lombok.Data;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.nd4j.linalg.factory.Nd4j;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.UUID;

@Data
public class Cluster implements Serializable {

    private String id = UUID.randomUUID().toString();
    private String label;

    private Point center;
    private List<Point> points = Collections.synchronizedList(new ArrayList<Point>());
    private boolean inverse = false;
    private Distance distanceFunction;

    public Cluster() {
        super();
    }

    /**
     * @param center the initial cluster center
     * @param distanceFunction the distance function to use
     */
    public Cluster(Point center, Distance distanceFunction) {
        this(center, false, distanceFunction);
    }

    /**
     * @param center the initial cluster center
     * @param inverse whether the distance function is inverted (a similarity, where higher means closer)
     * @param distanceFunction the distance function to use
     */
    public Cluster(Point center, boolean inverse, Distance distanceFunction) {
        this.distanceFunction = distanceFunction;
        this.inverse = inverse;
        setCenter(center);
    }

    /**
     * Get the distance to the given point from the cluster
     * @param point the point to get the distance for
     * @return the distance between the point and the cluster center
     */
    public double getDistanceToCenter(Point point) {
        return Nd4j.getExecutioner().execAndReturn(
                        ClusterUtils.createDistanceFunctionOp(distanceFunction, center.getArray(), point.getArray()))
                        .getFinalResult().doubleValue();
    }

    /**
     * Add a point to the cluster, moving the cluster center accordingly
     * @param point the point to add
     */
    public void addPoint(Point point) {
        addPoint(point, true);
    }

    /**
     * Add a point to the cluster
     * @param point the point to add
     * @param moveClusterCenter whether to update the cluster centroid or not
     */
    public void addPoint(Point point, boolean moveClusterCenter) {
        if (moveClusterCenter) {
            // incremental running-mean update: new = (old * n +/- p) / (n + 1)
            if (isInverse()) {
                center.getArray().muli(points.size()).subi(point.getArray()).divi(points.size() + 1);
            } else {
                center.getArray().muli(points.size()).addi(point.getArray()).divi(points.size() + 1);
            }
        }

        getPoints().add(point);
    }

    /**
     * Clear out the points
     */
    public void removePoints() {
        if (getPoints() != null)
            getPoints().clear();
    }

    /**
     * Whether the cluster is empty or not
     * @return true if the cluster contains no points
     */
    public boolean isEmpty() {
        return points == null || points.isEmpty();
    }

    /**
     * Return the point with the given id
     * @param id the id of the point to look up
     * @return the matching point, or null if none matches
     */
    public Point getPoint(String id) {
        for (Point point : points)
            if (id.equals(point.getId()))
                return point;
        return null;
    }

    /**
     * Remove the point with the given id and return it
     * @param id the id of the point to remove
     * @return the removed point, or null if none matches
     */
    public Point removePoint(String id) {
        Point removePoint = null;
        for (Point point : points)
            if (id.equals(point.getId()))
                removePoint = point;
        if (removePoint != null)
            points.remove(removePoint);
        return removePoint;
    }
}
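
The addPoint update maintains a running mean of the member points. A small worked check under assumed values:

    // Hypothetical check of the incremental centroid update; values are illustrative.
    Cluster c = new Cluster(new Point(Nd4j.createFromArray(new double[] {0.0, 0.0})), Distance.EUCLIDEAN);
    c.addPoint(new Point(Nd4j.createFromArray(new double[] {2.0, 2.0})));
    // points.size() was 0 when the update ran: center = ((0,0)*0 + (2,2)) / 1 = (2, 2)
    c.addPoint(new Point(Nd4j.createFromArray(new double[] {4.0, 4.0})));
    // points.size() was 1: center = ((2,2)*1 + (4,4)) / 2 = (3, 3)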

@ -1,259 +0,0 @@

package org.deeplearning4j.clustering.cluster;

import lombok.Data;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;

import java.io.Serializable;
import java.util.*;

@Data
public class ClusterSet implements Serializable {

    private Distance distanceFunction;
    private List<Cluster> clusters;
    private CentersHolder centersHolder;
    private Map<String, String> pointDistribution;
    private boolean inverse;

    public ClusterSet(boolean inverse) {
        this(null, inverse, null);
    }

    public ClusterSet(Distance distanceFunction, boolean inverse, long[] shape) {
        this.distanceFunction = distanceFunction;
        this.inverse = inverse;
        this.clusters = Collections.synchronizedList(new ArrayList<Cluster>());
        this.pointDistribution = Collections.synchronizedMap(new HashMap<String, String>());
        if (shape != null)
            this.centersHolder = new CentersHolder(shape[0], shape[1]);
    }

    public boolean isInverse() {
        return inverse;
    }

    /**
     * Create a new cluster with the given point as its center, and add it to the set
     * @param center the center of the new cluster
     * @return the newly created cluster
     */
    public Cluster addNewClusterWithCenter(Point center) {
        Cluster newCluster = new Cluster(center, distanceFunction);
        getClusters().add(newCluster);
        setPointLocation(center, newCluster);
        centersHolder.addCenter(center.getArray());
        return newCluster;
    }

    /**
     * Assign the point to its nearest cluster, moving that cluster's center
     * @param point the point to classify
     * @return the resulting classification
     */
    public PointClassification classifyPoint(Point point) {
        return classifyPoint(point, true);
    }

    /**
     * Classify each of the given points, moving cluster centers
     * @param points the points to classify
     */
    public void classifyPoints(List<Point> points) {
        classifyPoints(points, true);
    }

    /**
     * Classify each of the given points
     * @param points the points to classify
     * @param moveClusterCenter whether to update cluster centers as points are added
     */
    public void classifyPoints(List<Point> points, boolean moveClusterCenter) {
        for (Point point : points)
            classifyPoint(point, moveClusterCenter);
    }

    /**
     * Assign the point to its nearest cluster
     * @param point the point to classify
     * @param moveClusterCenter whether to update the cluster center
     * @return the resulting classification
     */
    public PointClassification classifyPoint(Point point, boolean moveClusterCenter) {
        Pair<Cluster, Double> nearestCluster = nearestCluster(point);
        Cluster newCluster = nearestCluster.getKey();
        boolean locationChange = isPointLocationChange(point, newCluster);
        addPointToCluster(point, newCluster, moveClusterCenter);
        return new PointClassification(nearestCluster.getKey(), nearestCluster.getValue(), locationChange);
    }

    private boolean isPointLocationChange(Point point, Cluster newCluster) {
        if (!getPointDistribution().containsKey(point.getId()))
            return true;
        return !getPointDistribution().get(point.getId()).equals(newCluster.getId());
    }

    private void addPointToCluster(Point point, Cluster cluster, boolean moveClusterCenter) {
        cluster.addPoint(point, moveClusterCenter);
        setPointLocation(point, cluster);
    }

    private void setPointLocation(Point point, Cluster cluster) {
        pointDistribution.put(point.getId(), cluster.getId());
    }

    /**
     * Find the nearest cluster to the given point
     * @param point the point to look up
     * @return the nearest cluster and the point's distance from its center
     */
    public Pair<Cluster, Double> nearestCluster(Point point) {
        Pair<Double, Long> nearestCenterData = centersHolder.getCenterByMinDistance(point, distanceFunction);
        Cluster nearestCluster = getClusters().get(nearestCenterData.getSecond().intValue());
        double minDistance = nearestCenterData.getFirst();
        return Pair.of(nearestCluster, minDistance);
    }

    /**
     * Compute the distance between two points under this set's distance function
     * @param m1 the first point
     * @param m2 the second point
     * @return the distance between the two points
     */
    public double getDistance(Point m1, Point m2) {
        return Nd4j.getExecutioner()
                        .execAndReturn(ClusterUtils.createDistanceFunctionOp(distanceFunction, m1.getArray(), m2.getArray()))
                        .getFinalResult().doubleValue();
    }

    /**
     * @param clusterId the id of the cluster
     * @return the id of that cluster's center point, or null if the cluster does not exist
     */
    public String getClusterCenterId(String clusterId) {
        Point clusterCenter = getClusterCenter(clusterId);
        return clusterCenter == null ? null : clusterCenter.getId();
    }

    /**
     * @param clusterId the id of the cluster
     * @return that cluster's center point, or null if the cluster does not exist
     */
    public Point getClusterCenter(String clusterId) {
        Cluster cluster = getCluster(clusterId);
        return cluster == null ? null : cluster.getCenter();
    }

    /**
     * @param id the id of the cluster to look up
     * @return the matching cluster, or null if none matches
     */
    public Cluster getCluster(String id) {
        for (int i = 0, j = clusters.size(); i < j; i++)
            if (id.equals(clusters.get(i).getId()))
                return clusters.get(i);
        return null;
    }

    /**
     * @return the number of clusters in the set
     */
    public int getClusterCount() {
        return getClusters() == null ? 0 : getClusters().size();
    }

    /**
     * Remove all points from all clusters, keeping the centers
     */
    public void removePoints() {
        for (Cluster cluster : getClusters())
            cluster.removePoints();
    }

    /**
     * @param count the number of clusters to return
     * @return the count clusters containing the most points, most populated first
     */
    public List<Cluster> getMostPopulatedClusters(int count) {
        List<Cluster> mostPopulated = new ArrayList<>(clusters);
        Collections.sort(mostPopulated, new Comparator<Cluster>() {
            public int compare(Cluster o1, Cluster o2) {
                return Integer.compare(o2.getPoints().size(), o1.getPoints().size());
            }
        });
        return mostPopulated.subList(0, count);
    }

    /**
     * Remove all empty clusters from the set
     * @return the clusters that were removed
     */
    public List<Cluster> removeEmptyClusters() {
        List<Cluster> emptyClusters = new ArrayList<>();
        for (Cluster cluster : clusters)
            if (cluster.isEmpty())
                emptyClusters.add(cluster);
        clusters.removeAll(emptyClusters);
        return emptyClusters;
    }
}
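
A short, hypothetical end-to-end sketch of the classify path. The seed centers, shape, and query values are assumptions made for illustration:

    ClusterSet set = new ClusterSet(Distance.EUCLIDEAN, false, new long[] {2, 2});
    set.addNewClusterWithCenter(new Point(Nd4j.createFromArray(new double[] {0.0, 0.0})));
    set.addNewClusterWithCenter(new Point(Nd4j.createFromArray(new double[] {10.0, 10.0})));

    PointClassification pc = set.classifyPoint(new Point(Nd4j.createFromArray(new double[] {9.0, 9.0})));
    // pc.getCluster() is the cluster seeded at (10, 10);
    // pc.isNewLocation() is true because the point had no previous assignment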

@ -1,531 +0,0 @@

package org.deeplearning4j.clustering.cluster;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.info.ClusterInfo;
import org.deeplearning4j.clustering.info.ClusterSetInfo;
import org.deeplearning4j.clustering.optimisation.ClusteringOptimizationType;
import org.deeplearning4j.clustering.strategy.OptimisationStrategy;
import org.deeplearning4j.clustering.util.MathUtils;
import org.deeplearning4j.clustering.util.MultiThreadUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.reduce3.*;
import org.nd4j.linalg.factory.Nd4j;

import java.util.*;
import java.util.concurrent.ExecutorService;

@NoArgsConstructor(access = AccessLevel.PRIVATE)
@Slf4j
public class ClusterUtils {

    /** Classify the set of points based on cluster centers. This also adds each point to the ClusterSet */
    public static ClusterSetInfo classifyPoints(final ClusterSet clusterSet, List<Point> points,
                    ExecutorService executorService) {
        final ClusterSetInfo clusterSetInfo = ClusterSetInfo.initialize(clusterSet, true);

        // Runs inline on the calling thread; executorService is retained for API compatibility.
        for (final Point point : points) {
            try {
                PointClassification result = classifyPoint(clusterSet, point);
                if (result.isNewLocation())
                    clusterSetInfo.getPointLocationChange().incrementAndGet();
                clusterSetInfo.getClusterInfo(result.getCluster().getId()).getPointDistancesFromCenter()
                                .put(point.getId(), result.getDistanceFromCenter());
            } catch (Throwable t) {
                log.warn("Error classifying point", t);
            }
        }

        return clusterSetInfo;
    }

    public static PointClassification classifyPoint(ClusterSet clusterSet, Point point) {
        return clusterSet.classifyPoint(point, false);
    }

    public static void refreshClustersCenters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo,
                    ExecutorService executorService) {
        int nClusters = clusterSet.getClusterCount();
        // Runs inline on the calling thread; executorService is retained for API compatibility.
        for (int i = 0; i < nClusters; i++) {
            final Cluster cluster = clusterSet.getClusters().get(i);
            try {
                final ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
                refreshClusterCenter(cluster, clusterInfo);
                deriveClusterInfoDistanceStatistics(clusterInfo);
            } catch (Throwable t) {
                log.warn("Error refreshing cluster centers", t);
            }
        }
    }

    public static void refreshClusterCenter(Cluster cluster, ClusterInfo clusterInfo) {
        int pointsCount = cluster.getPoints().size();
        if (pointsCount == 0)
            return;
        // Recompute the center as the mean (or negated mean, for inverse clusters) of the member points
        Point center = new Point(Nd4j.create(cluster.getPoints().get(0).getArray().length()));
        for (Point point : cluster.getPoints()) {
            INDArray arr = point.getArray();
            if (cluster.isInverse())
                center.getArray().subi(arr);
            else
                center.getArray().addi(arr);
        }
        center.getArray().divi(pointsCount);
        cluster.setCenter(center);
    }

    /**
     * Derive the max/total/average/variance statistics from the per-point distances
     * @param info the cluster info to update in place
     */
    public static void deriveClusterInfoDistanceStatistics(ClusterInfo info) {
        int pointCount = info.getPointDistancesFromCenter().size();
        if (pointCount == 0)
            return;

        double[] distances =
                        ArrayUtils.toPrimitive(info.getPointDistancesFromCenter().values().toArray(new Double[] {}));
        double max = info.isInverse() ? MathUtils.min(distances) : MathUtils.max(distances);
        double total = MathUtils.sum(distances);
        info.setMaxPointDistanceFromCenter(max);
        info.setTotalPointDistanceFromCenter(total);
        info.setAveragePointDistanceFromCenter(total / pointCount);
        info.setPointDistanceFromCenterVariance(MathUtils.variance(distances));
    }

    /**
     * Compute, for each point, the squared distance to the most recently added cluster,
     * keeping the previous minimum wherever it is smaller
     * @param clusterSet the cluster set
     * @param points the points to measure
     * @param previousDxs the previous per-point minimum distances
     * @param executorService retained for API compatibility; the computation runs inline
     * @return the updated per-point distances
     */
    public static INDArray computeSquareDistancesFromNearestCluster(final ClusterSet clusterSet,
                    final List<Point> points, INDArray previousDxs, ExecutorService executorService) {
        final int pointsCount = points.size();
        final INDArray dxs = Nd4j.create(pointsCount);
        final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1);

        for (int i = 0; i < pointsCount; i++) {
            try {
                Point point = points.get(i);
                double dist = clusterSet.isInverse() ? newCluster.getDistanceToCenter(point)
                                : Math.pow(newCluster.getDistanceToCenter(point), 2);
                dxs.putScalar(i, dist);
            } catch (Throwable t) {
                log.warn("Error computing squared distance from nearest cluster", t);
            }
        }

        for (int i = 0; i < pointsCount; i++) {
            double previousMinDistance = previousDxs.getDouble(i);
            if (clusterSet.isInverse()) {
                if (dxs.getDouble(i) < previousMinDistance) {
                    dxs.putScalar(i, previousMinDistance);
                }
            } else if (dxs.getDouble(i) > previousMinDistance)
                dxs.putScalar(i, previousMinDistance);
        }

        return dxs;
    }

    /**
     * Compute the running (cumulative) sum of squared distances to the most recently
     * added cluster, as used for weighted, k-means++-style sampling of the next center
     */
    public static INDArray computeWeightedProbaDistancesFromNearestCluster(final ClusterSet clusterSet,
                    final List<Point> points, INDArray previousDxs) {
        final int pointsCount = points.size();
        final INDArray dxs = Nd4j.create(pointsCount);
        final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1);

        double sum = 0.0;
        for (int i = 0; i < pointsCount; i++) {
            Point point = points.get(i);
            double dist = Math.pow(newCluster.getDistanceToCenter(point), 2);
            sum += dist;
            dxs.putScalar(i, sum);
        }

        return dxs;
    }

    /**
     * Compute the info (distance statistics) for every cluster in the set
     * @param clusterSet the cluster set to describe
     * @return the computed info
     */
    public static ClusterSetInfo computeClusterSetInfo(ClusterSet clusterSet) {
        ExecutorService executor = MultiThreadUtils.newExecutorService();
        ClusterSetInfo info = computeClusterSetInfo(clusterSet, executor);
        executor.shutdownNow();
        return info;
    }

    public static ClusterSetInfo computeClusterSetInfo(final ClusterSet clusterSet, ExecutorService executorService) {
        final ClusterSetInfo info = new ClusterSetInfo(clusterSet.isInverse(), true);
        int clusterCount = clusterSet.getClusterCount();

        // Per-cluster distance statistics; runs inline on the calling thread.
        for (int i = 0; i < clusterCount; i++) {
            final Cluster cluster = clusterSet.getClusters().get(i);
            try {
                info.getClustersInfos().put(cluster.getId(),
                                computeClusterInfos(cluster, clusterSet.getDistanceFunction()));
            } catch (Throwable t) {
                log.warn("Error computing cluster set info", t);
            }
        }

        // Pairwise distances between cluster centers.
        for (int i = 0; i < clusterCount; i++) {
            final int clusterIdx = i;
            final Cluster fromCluster = clusterSet.getClusters().get(i);
            try {
                for (int k = clusterIdx + 1, l = clusterSet.getClusterCount(); k < l; k++) {
                    Cluster toCluster = clusterSet.getClusters().get(k);
                    double distance = Nd4j.getExecutioner()
                                    .execAndReturn(ClusterUtils.createDistanceFunctionOp(
                                                    clusterSet.getDistanceFunction(),
                                                    fromCluster.getCenter().getArray(),
                                                    toCluster.getCenter().getArray()))
                                    .getFinalResult().doubleValue();
                    info.getDistancesBetweenClustersCenters().put(fromCluster.getId(), toCluster.getId(),
                                    distance);
                }
            } catch (Throwable t) {
                log.warn("Error computing distances", t);
            }
        }

        return info;
    }

    /**
     * Compute the per-point distance statistics for a single cluster
     * @param cluster the cluster to describe
     * @param distanceFunction the distance function to use
     * @return the computed info
     */
    public static ClusterInfo computeClusterInfos(Cluster cluster, Distance distanceFunction) {
        // ClusterInfo's two-arg constructor takes (threadSafe, inverse)
        ClusterInfo info = new ClusterInfo(true, cluster.isInverse());
        for (int i = 0, j = cluster.getPoints().size(); i < j; i++) {
            Point point = cluster.getPoints().get(i);
            //shouldn't need to inverse here. other parts of
            //the code should interpret the "distance" or score here
            double distance = Nd4j.getExecutioner()
                            .execAndReturn(ClusterUtils.createDistanceFunctionOp(distanceFunction,
                                            cluster.getCenter().getArray(), point.getArray()))
                            .getFinalResult().doubleValue();
            info.getPointDistancesFromCenter().put(point.getId(), distance);
            double diff = info.getTotalPointDistanceFromCenter() + distance;
            info.setTotalPointDistanceFromCenter(diff);
        }

        if (!cluster.getPoints().isEmpty())
            info.setAveragePointDistanceFromCenter(info.getTotalPointDistanceFromCenter() / cluster.getPoints().size());
        return info;
    }

    /**
     * Apply the given optimisation strategy: split clusters that violate the average-
     * or maximum-distance constraint, depending on the optimisation type
     * @return true if at least one cluster was split
     */
    public static boolean applyOptimization(OptimisationStrategy optimization, ClusterSet clusterSet,
                    ClusterSetInfo clusterSetInfo, ExecutorService executor) {

        if (optimization.isClusteringOptimizationType(
                        ClusteringOptimizationType.MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE)) {
            int splitCount = ClusterUtils.splitClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet,
                            clusterSetInfo, optimization.getClusteringOptimizationValue(), executor);
            return splitCount > 0;
        }

        if (optimization.isClusteringOptimizationType(
                        ClusteringOptimizationType.MINIMIZE_MAXIMUM_POINT_TO_CENTER_DISTANCE)) {
            int splitCount = ClusterUtils.splitClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet,
                            clusterSetInfo, optimization.getClusteringOptimizationValue(), executor);
            return splitCount > 0;
        }

        return false;
    }

    /**
     * @param count the number of clusters to return
     * @return the count clusters with the largest total point-to-center distance
     */
    public static List<Cluster> getMostSpreadOutClusters(final ClusterSet clusterSet, final ClusterSetInfo info,
                    int count) {
        List<Cluster> clusters = new ArrayList<>(clusterSet.getClusters());
        Collections.sort(clusters, new Comparator<Cluster>() {
            public int compare(Cluster o1, Cluster o2) {
                Double o1TotalDistance = info.getClusterInfo(o1.getId()).getTotalPointDistanceFromCenter();
                Double o2TotalDistance = info.getClusterInfo(o2.getId()).getTotalPointDistanceFromCenter();
                int comp = o1TotalDistance.compareTo(o2TotalDistance);
                return !clusterSet.getClusters().get(0).isInverse() ? -comp : comp;
            }
        });

        return clusters.subList(0, count);
    }

    /**
     * @param maximumAverageDistance the average-distance bound
     * @return the clusters whose average point-to-center distance violates the bound
     */
    public static List<Cluster> getClustersWhereAverageDistanceFromCenterGreaterThan(final ClusterSet clusterSet,
                    final ClusterSetInfo info, double maximumAverageDistance) {
        List<Cluster> clusters = new ArrayList<>();
        for (Cluster cluster : clusterSet.getClusters()) {
            ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId());
            if (clusterInfo != null) {
                if (clusterInfo.isInverse()) {
                    if (clusterInfo.getAveragePointDistanceFromCenter() < maximumAverageDistance)
                        clusters.add(cluster);
                } else {
                    if (clusterInfo.getAveragePointDistanceFromCenter() > maximumAverageDistance)
                        clusters.add(cluster);
                }
            }
        }
        return clusters;
    }

    /**
     * @param maximumDistance the maximum-distance bound
     * @return the clusters whose maximum point-to-center distance violates the bound
     */
    public static List<Cluster> getClustersWhereMaximumDistanceFromCenterGreaterThan(final ClusterSet clusterSet,
                    final ClusterSetInfo info, double maximumDistance) {
        List<Cluster> clusters = new ArrayList<>();
        for (Cluster cluster : clusterSet.getClusters()) {
            ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId());
            if (clusterInfo != null) {
                if (clusterInfo.isInverse() && clusterInfo.getMaxPointDistanceFromCenter() < maximumDistance) {
                    clusters.add(cluster);
                } else if (clusterInfo.getMaxPointDistanceFromCenter() > maximumDistance) {
                    clusters.add(cluster);
                }
            }
        }
        return clusters;
    }

    /**
     * Split the count most spread-out clusters
     * @return the number of clusters split
     */
    public static int splitMostSpreadOutClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count,
                    ExecutorService executorService) {
        List<Cluster> clustersToSplit = getMostSpreadOutClusters(clusterSet, clusterSetInfo, count);
        splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService);
        return clustersToSplit.size();
    }

    /**
     * Split every cluster whose average point-to-center distance exceeds the bound
     * @return the number of clusters split
     */
    public static int splitClustersWhereAverageDistanceFromCenterGreaterThan(ClusterSet clusterSet,
                    ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) {
        List<Cluster> clustersToSplit = getClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo,
                        maxWithinClusterDistance);
        splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService);
        return clustersToSplit.size();
    }

    /**
     * Split every cluster whose maximum point-to-center distance exceeds the bound
     * @return the number of clusters split
     */
    public static int splitClustersWhereMaximumDistanceFromCenterGreaterThan(ClusterSet clusterSet,
                    ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) {
        List<Cluster> clustersToSplit = getClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo,
                        maxWithinClusterDistance);
        splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService);
        return clustersToSplit.size();
    }

    /**
     * Split the count most populated clusters
     */
    public static void splitMostPopulatedClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count,
                    ExecutorService executorService) {
        List<Cluster> clustersToSplit = clusterSet.getMostPopulatedClusters(count);
        splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService);
    }

    /**
     * Split each given cluster by promoting one of its points farther from the
     * center than maxDistance to be the center of a new cluster
     */
    public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo,
                    List<Cluster> clusters, final double maxDistance, ExecutorService executorService) {
        final Random random = new Random();
        List<Runnable> tasks = new ArrayList<>();
        for (final Cluster cluster : clusters) {
            tasks.add(new Runnable() {
                public void run() {
                    try {
                        ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
                        List<String> fartherPoints = clusterInfo.getPointsFartherFromCenterThan(maxDistance);
                        // pick one of the (up to) three first candidates as the new center
                        int rank = Math.min(fartherPoints.size(), 3);
                        String pointId = fartherPoints.get(random.nextInt(rank));
                        Point point = cluster.removePoint(pointId);
                        clusterSet.addNewClusterWithCenter(point);
                    } catch (Throwable t) {
                        log.warn("Error splitting clusters", t);
                    }
                }
            });
        }
        MultiThreadUtils.parallelTasks(tasks, executorService);
    }

    /**
     * Split each given cluster by promoting one of its points, chosen at random,
     * to be the center of a new cluster
     */
    public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo,
                    List<Cluster> clusters, ExecutorService executorService) {
        final Random random = new Random();
        List<Runnable> tasks = new ArrayList<>();
        for (final Cluster cluster : clusters) {
            tasks.add(new Runnable() {
                public void run() {
                    try {
                        Point point = cluster.getPoints().remove(random.nextInt(cluster.getPoints().size()));
                        clusterSet.addNewClusterWithCenter(point);
                    } catch (Throwable t) {
                        log.warn("Error splitting clusters (2)", t);
                    }
                }
            });
        }

        MultiThreadUtils.parallelTasks(tasks, executorService);
    }

    public static ReduceOp createDistanceFunctionOp(Distance distanceFunction, INDArray x, INDArray y, int... dimensions) {
        val op = createDistanceFunctionOp(distanceFunction, x, y);
        op.setDimensions(dimensions);
        return op;
    }

    public static ReduceOp createDistanceFunctionOp(Distance distanceFunction, INDArray x, INDArray y) {
        switch (distanceFunction) {
            case COSINE_DISTANCE:
                return new CosineDistance(x, y);
            case COSINE_SIMILARITY:
                return new CosineSimilarity(x, y);
            case DOT:
                return new Dot(x, y);
            case EUCLIDEAN:
                return new EuclideanDistance(x, y);
            case JACCARD:
                return new JaccardDistance(x, y);
            case MANHATTAN:
                return new ManhattanDistance(x, y);
            default:
                throw new IllegalStateException("Unknown distance function: " + distanceFunction);
        }
    }
}
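
The distance-op factory can be exercised in isolation, following the same executioner idiom the class itself uses. A minimal sketch (array values assumed for illustration):

    INDArray x = Nd4j.createFromArray(new double[] {0.0, 0.0});
    INDArray y = Nd4j.createFromArray(new double[] {3.0, 4.0});
    ReduceOp op = ClusterUtils.createDistanceFunctionOp(Distance.EUCLIDEAN, x, y);
    double d = Nd4j.getExecutioner().execAndReturn(op).getFinalResult().doubleValue();
    // d == 5.0 for this 3-4-5 triangle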

@ -1,107 +0,0 @@

package org.deeplearning4j.clustering.cluster;

import lombok.AccessLevel;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.UUID;

/**
 * A single point (feature vector) that can be assigned to a {@link Cluster}.
 */
@Data
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public class Point implements Serializable {

    private static final long serialVersionUID = -6658028541426027226L;

    private String id = UUID.randomUUID().toString();
    private String label;
    private INDArray array;

    /**
     * @param array the underlying vector for this point
     */
    public Point(INDArray array) {
        super();
        this.array = array;
    }

    /**
     * @param id the id of this point
     * @param array the underlying vector for this point
     */
    public Point(String id, INDArray array) {
        super();
        this.id = id;
        this.array = array;
    }

    public Point(String id, String label, double[] data) {
        this(id, label, Nd4j.create(data));
    }

    public Point(String id, String label, INDArray array) {
        super();
        this.id = id;
        this.label = label;
        this.array = array;
    }

    /**
     * Wrap each row of the given matrix in a Point
     * @param matrix the matrix to convert
     * @return one Point per row
     */
    public static List<Point> toPoints(INDArray matrix) {
        List<Point> arr = new ArrayList<>(matrix.rows());
        for (int i = 0; i < matrix.rows(); i++) {
            arr.add(new Point(matrix.getRow(i)));
        }

        return arr;
    }

    /**
     * Wrap each of the given vectors in a Point
     * @param vectors the vectors to convert
     * @return one Point per vector
     */
    public static List<Point> toPoints(List<INDArray> vectors) {
        List<Point> points = new ArrayList<>();
        for (INDArray vector : vectors)
            points.add(new Point(vector));
        return points;
    }
}
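
A one-line sketch of the row-wise conversion (the 5x3 random matrix is an assumption for illustration):

    INDArray data = Nd4j.rand(5, 3);           // 5 points in 3 dimensions
    List<Point> points = Point.toPoints(data); // one Point per row, each with a fresh UUID id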

@ -1,40 +0,0 @@

package org.deeplearning4j.clustering.cluster;

import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.io.Serializable;

@Data
@NoArgsConstructor(access = AccessLevel.PROTECTED)
@AllArgsConstructor
public class PointClassification implements Serializable {

    private Cluster cluster;
    private double distanceFromCenter;
    private boolean newLocation;
}

@ -1,37 +0,0 @@

package org.deeplearning4j.clustering.condition;

import org.deeplearning4j.clustering.iteration.IterationHistory;

/**
 * A stopping/continuation condition evaluated against the iteration history of a clustering run.
 */
public interface ClusteringAlgorithmCondition {

    /**
     * @param iterationHistory the history of the clustering iterations so far
     * @return true if the condition is satisfied
     */
    boolean isSatisfied(IterationHistory iterationHistory);
}
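
As a sketch of this extension point, a hypothetical condition that stops after a wall-clock budget. This class is illustrative only and was never part of the removed module:

    // Hypothetical example implementation of ClusteringAlgorithmCondition.
    public class TimeBudgetCondition implements ClusteringAlgorithmCondition {
        private final long deadlineMillis; // assumed field, for illustration

        public TimeBudgetCondition(long budgetMillis) {
            this.deadlineMillis = System.currentTimeMillis() + budgetMillis;
        }

        @Override
        public boolean isSatisfied(IterationHistory iterationHistory) {
            // satisfied (i.e. stop iterating) once the time budget is exhausted
            return System.currentTimeMillis() >= deadlineMillis;
        }
    }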

@ -1,69 +0,0 @@

package org.deeplearning4j.clustering.condition;

import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor;
import org.deeplearning4j.clustering.iteration.IterationHistory;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.LessThan;

import java.io.Serializable;

/**
 * A condition that is satisfied once the fraction of points changing cluster
 * per iteration drops below a threshold.
 */
@NoArgsConstructor(access = AccessLevel.PROTECTED)
@AllArgsConstructor(access = AccessLevel.PROTECTED)
public class ConvergenceCondition implements ClusteringAlgorithmCondition, Serializable {

    private Condition convergenceCondition;
    private double pointsDistributionChangeRate;

    /**
     * @param pointsDistributionChangeRate the maximum fraction of points allowed to change cluster per iteration
     * @return a condition satisfied once the change rate drops below the given threshold
     */
    public static ConvergenceCondition distributionVariationRateLessThan(double pointsDistributionChangeRate) {
        Condition condition = new LessThan(pointsDistributionChangeRate);
        return new ConvergenceCondition(condition, pointsDistributionChangeRate);
    }

    /**
     * @param iterationHistory the history of the clustering iterations so far
     * @return true once the fraction of points that changed cluster is below the threshold
     */
    public boolean isSatisfied(IterationHistory iterationHistory) {
        int iterationCount = iterationHistory.getIterationCount();
        if (iterationCount <= 1)
            return false;

        double variation = iterationHistory.getMostRecentClusterSetInfo().getPointLocationChange().get();
        variation /= iterationHistory.getMostRecentClusterSetInfo().getPointsCount();

        return convergenceCondition.apply(variation);
    }
}
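
Usage sketch (the 1% threshold is an arbitrary illustrative value):

    ClusteringAlgorithmCondition converged = ConvergenceCondition.distributionVariationRateLessThan(0.01);
    // satisfied once fewer than 1% of points changed cluster in the latest iteration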

@ -1,61 +0,0 @@

package org.deeplearning4j.clustering.condition;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.deeplearning4j.clustering.iteration.IterationHistory;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.GreaterThanOrEqual;

import java.io.Serializable;

/**
 * A condition that is satisfied once a fixed number of iterations has been reached.
 */
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public class FixedIterationCountCondition implements ClusteringAlgorithmCondition, Serializable {

    private Condition iterationCountCondition;

    protected FixedIterationCountCondition(int initialClusterCount) {
        iterationCountCondition = new GreaterThanOrEqual(initialClusterCount);
    }

    /**
     * @param iterationCount the iteration count at or above which the condition is satisfied
     * @return the condition
     */
    public static FixedIterationCountCondition iterationCountGreaterThan(int iterationCount) {
        return new FixedIterationCountCondition(iterationCount);
    }

    /**
     * @param iterationHistory the history of the clustering iterations so far
     * @return true once the iteration count reaches the configured bound
     */
    public boolean isSatisfied(IterationHistory iterationHistory) {
        return iterationCountCondition.apply(iterationHistory == null ? 0 : iterationHistory.getIterationCount());
    }
}
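
Usage sketch (the bound of 10 iterations is illustrative):

    ClusteringAlgorithmCondition done = FixedIterationCountCondition.iterationCountGreaterThan(10);
    // satisfied from the 10th recorded iteration onward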

@ -1,82 +0,0 @@

package org.deeplearning4j.clustering.condition;

import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor;
import org.deeplearning4j.clustering.iteration.IterationHistory;
import org.nd4j.linalg.indexing.conditions.Condition;
import org.nd4j.linalg.indexing.conditions.LessThan;

import java.io.Serializable;

/**
 * A condition that is satisfied once the variance of point-to-center distances
 * has stabilized, i.e. its relative variation stays below a threshold for a
 * given number of consecutive iterations.
 */
@NoArgsConstructor(access = AccessLevel.PROTECTED)
@AllArgsConstructor
public class VarianceVariationCondition implements ClusteringAlgorithmCondition, Serializable {

    private Condition varianceVariationCondition;
    private int period;

    /**
     * @param varianceVariation the maximum allowed relative variation of the variance
     * @param period the number of consecutive iterations the bound must hold
     * @return the condition
     */
    public static VarianceVariationCondition varianceVariationLessThan(double varianceVariation, int period) {
        Condition condition = new LessThan(varianceVariation);
        return new VarianceVariationCondition(condition, period);
    }

    /**
     * @param iterationHistory the history of the clustering iterations so far
     * @return true once the variance variation has stayed below the bound for the whole period
     */
    public boolean isSatisfied(IterationHistory iterationHistory) {
        if (iterationHistory.getIterationCount() <= period)
            return false;

        for (int i = 0, j = iterationHistory.getIterationCount(); i < period; i++) {
            double variation = iterationHistory.getIterationInfo(j - i).getClusterSetInfo()
                            .getPointDistanceFromClusterVariance();
            variation -= iterationHistory.getIterationInfo(j - i - 1).getClusterSetInfo()
                            .getPointDistanceFromClusterVariance();
            variation /= iterationHistory.getIterationInfo(j - i - 1).getClusterSetInfo()
                            .getPointDistanceFromClusterVariance();

            if (!varianceVariationCondition.apply(variation))
                return false;
        }

        return true;
    }
}
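
In short, for each of the last period iterations the loop computes the relative change (v_i - v_{i-1}) / v_{i-1} of the distance variance and requires it to stay under the bound. Usage sketch (threshold and period are illustrative):

    ClusteringAlgorithmCondition stable = VarianceVariationCondition.varianceVariationLessThan(0.001, 3);
    // satisfied once the variance moved by less than 0.1% across each of the last 3 iterations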

@ -1,114 +0,0 @@

package org.deeplearning4j.clustering.info;

import lombok.Data;

import java.io.Serializable;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Distance statistics for a single cluster: per-point distances from the
 * center, plus their max, total, average and variance.
 */
@Data
public class ClusterInfo implements Serializable {

    private double averagePointDistanceFromCenter;
    private double maxPointDistanceFromCenter;
    private double pointDistanceFromCenterVariance;
    private double totalPointDistanceFromCenter;
    private boolean inverse;
    private Map<String, Double> pointDistancesFromCenter = new ConcurrentHashMap<>();

    public ClusterInfo(boolean inverse) {
        this(false, inverse);
    }

    /**
     * @param threadSafe whether to wrap the distance map in a synchronized view
     * @param inverse whether the distance function is inverted (a similarity)
     */
    public ClusterInfo(boolean threadSafe, boolean inverse) {
        super();
        this.inverse = inverse;
        if (threadSafe) {
            pointDistancesFromCenter = Collections.synchronizedMap(pointDistancesFromCenter);
        }
    }

    /**
     * @return the point distances from the center, in ascending order
     */
    public Set<Map.Entry<String, Double>> getSortedPointDistancesFromCenter() {
        SortedSet<Map.Entry<String, Double>> sortedEntries = new TreeSet<>(new Comparator<Map.Entry<String, Double>>() {
            @Override
            public int compare(Map.Entry<String, Double> e1, Map.Entry<String, Double> e2) {
                int res = e1.getValue().compareTo(e2.getValue());
                // never return 0, so entries with equal distances are all kept in the set
                return res != 0 ? res : 1;
            }
        });
        sortedEntries.addAll(pointDistancesFromCenter.entrySet());
        return sortedEntries;
    }

    /**
     * @return the point distances from the center, in descending order
     */
    public Set<Map.Entry<String, Double>> getReverseSortedPointDistancesFromCenter() {
        SortedSet<Map.Entry<String, Double>> sortedEntries = new TreeSet<>(new Comparator<Map.Entry<String, Double>>() {
            @Override
            public int compare(Map.Entry<String, Double> e1, Map.Entry<String, Double> e2) {
                int res = e1.getValue().compareTo(e2.getValue());
                return -(res != 0 ? res : 1);
            }
        });
        sortedEntries.addAll(pointDistancesFromCenter.entrySet());
        return sortedEntries;
    }

    /**
     * @param maxDistance the distance threshold
     * @return point ids gathered by scanning the descending distances until the threshold check trips
     */
    public List<String> getPointsFartherFromCenterThan(double maxDistance) {
        Set<Map.Entry<String, Double>> sorted = getReverseSortedPointDistancesFromCenter();
        List<String> ids = new ArrayList<>();
        for (Map.Entry<String, Double> entry : sorted) {
            if (inverse && entry.getValue() < -maxDistance) {
                break;
            } else if (entry.getValue() > maxDistance)
                break;

            ids.add(entry.getKey());
        }
        return ids;
    }
}

@ -1,142 +0,0 @@

package org.deeplearning4j.clustering.info;

import org.nd4j.shade.guava.collect.HashBasedTable;
import org.nd4j.shade.guava.collect.Table;
import org.deeplearning4j.clustering.cluster.Cluster;
import org.deeplearning4j.clustering.cluster.ClusterSet;

import java.io.Serializable;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

public class ClusterSetInfo implements Serializable {

    private Map<String, ClusterInfo> clustersInfos = new HashMap<>();
    private Table<String, String, Double> distancesBetweenClustersCenters = HashBasedTable.create();
    private AtomicInteger pointLocationChange;
    private boolean threadSafe;
    private boolean inverse;

    public ClusterSetInfo(boolean inverse) {
        this(inverse, false);
    }

    /**
     * @param inverse whether the distance function is inverted (a similarity)
     * @param threadSafe whether to wrap the info map in a synchronized view
     */
    public ClusterSetInfo(boolean inverse, boolean threadSafe) {
        this.pointLocationChange = new AtomicInteger(0);
        this.threadSafe = threadSafe;
        this.inverse = inverse;
        if (threadSafe) {
            clustersInfos = Collections.synchronizedMap(clustersInfos);
        }
    }

    /**
     * @param clusterSet the cluster set to describe
     * @param threadSafe whether the info structures should be thread safe
     * @return a ClusterSetInfo with one (empty) ClusterInfo per cluster
     */
    public static ClusterSetInfo initialize(ClusterSet clusterSet, boolean threadSafe) {
        ClusterSetInfo info = new ClusterSetInfo(clusterSet.isInverse(), threadSafe);
        for (int i = 0, j = clusterSet.getClusterCount(); i < j; i++)
            info.addClusterInfo(clusterSet.getClusters().get(i).getId());
        return info;
    }

    public void removeClusterInfos(List<Cluster> clusters) {
        for (Cluster cluster : clusters) {
            clustersInfos.remove(cluster.getId());
        }
    }

    public ClusterInfo addClusterInfo(String clusterId) {
        // pass both flags explicitly; ClusterInfo's two-arg constructor is (threadSafe, inverse)
        ClusterInfo clusterInfo = new ClusterInfo(this.threadSafe, this.inverse);
        clustersInfos.put(clusterId, clusterInfo);
        return clusterInfo;
    }

    public ClusterInfo getClusterInfo(String clusterId) {
        return clustersInfos.get(clusterId);
    }

    public double getAveragePointDistanceFromClusterCenter() {
        if (clustersInfos == null || clustersInfos.isEmpty())
            return 0;

        double average = 0;
        for (ClusterInfo info : clustersInfos.values())
            average += info.getAveragePointDistanceFromCenter();
        return average / clustersInfos.size();
    }

    public double getPointDistanceFromClusterVariance() {
        if (clustersInfos == null || clustersInfos.isEmpty())
            return 0;

        double average = 0;
        for (ClusterInfo info : clustersInfos.values())
            average += info.getPointDistanceFromCenterVariance();
        return average / clustersInfos.size();
    }

    public int getPointsCount() {
        int count = 0;
        for (ClusterInfo clusterInfo : clustersInfos.values())
            count += clusterInfo.getPointDistancesFromCenter().size();
        return count;
    }

    public Map<String, ClusterInfo> getClustersInfos() {
        return clustersInfos;
    }

    public void setClustersInfos(Map<String, ClusterInfo> clustersInfos) {
        this.clustersInfos = clustersInfos;
    }

    public Table<String, String, Double> getDistancesBetweenClustersCenters() {
        return distancesBetweenClustersCenters;
    }

    public void setDistancesBetweenClustersCenters(Table<String, String, Double> interClusterDistances) {
        this.distancesBetweenClustersCenters = interClusterDistances;
    }

    public AtomicInteger getPointLocationChange() {
        return pointLocationChange;
    }

    public void setPointLocationChange(AtomicInteger pointLocationChange) {
        this.pointLocationChange = pointLocationChange;
    }
}

@ -1,72 +0,0 @@

package org.deeplearning4j.clustering.iteration;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
import org.deeplearning4j.clustering.info.ClusterSetInfo;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
public class IterationHistory implements Serializable {
|
||||
@Getter
|
||||
@Setter
|
||||
private Map<Integer, IterationInfo> iterationsInfos = new HashMap<>();
|
||||
|
||||
/**
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public ClusterSetInfo getMostRecentClusterSetInfo() {
|
||||
IterationInfo iterationInfo = getMostRecentIterationInfo();
|
||||
return iterationInfo == null ? null : iterationInfo.getClusterSetInfo();
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public IterationInfo getMostRecentIterationInfo() {
|
||||
return getIterationInfo(getIterationCount() - 1);
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @return
|
||||
*/
|
||||
public int getIterationCount() {
|
||||
return getIterationsInfos().size();
|
||||
}
|
||||
|
||||
/**
|
||||
*
|
||||
* @param iterationIdx
|
||||
* @return
|
||||
*/
|
||||
public IterationInfo getIterationInfo(int iterationIdx) {
|
||||
return getIterationsInfos().get(iterationIdx);
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
@ -1,49 +0,0 @@
/* Apache License, Version 2.0. SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.iteration;

import lombok.AccessLevel;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.deeplearning4j.clustering.info.ClusterSetInfo;

import java.io.Serializable;

@Data
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public class IterationInfo implements Serializable {

    private int index;
    private ClusterSetInfo clusterSetInfo;
    private boolean strategyApplied;

    public IterationInfo(int index) {
        this.index = index;
    }

    public IterationInfo(int index, ClusterSetInfo clusterSetInfo) {
        this.index = index;
        this.clusterSetInfo = clusterSetInfo;
    }
}
@ -1,142 +0,0 @@
/* Apache License, Version 2.0. SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.kdtree;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.custom.KnnMinDistance;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;

import java.io.Serializable;

public class HyperRect implements Serializable {

    private float[] lowerEnds;
    private float[] higherEnds;
    private INDArray lowerEndsIND;
    private INDArray higherEndsIND;

    public HyperRect(float[] lowerEndsIn, float[] higherEndsIn) {
        this.lowerEnds = new float[lowerEndsIn.length];
        this.higherEnds = new float[lowerEndsIn.length];
        System.arraycopy(lowerEndsIn, 0, this.lowerEnds, 0, lowerEndsIn.length);
        System.arraycopy(higherEndsIn, 0, this.higherEnds, 0, higherEndsIn.length);
        lowerEndsIND = Nd4j.createFromArray(lowerEnds);
        higherEndsIND = Nd4j.createFromArray(higherEnds);
    }

    public HyperRect(float[] point) {
        this(point, point);
    }

    public HyperRect(Pair<float[], float[]> ends) {
        this(ends.getFirst(), ends.getSecond());
    }

    /** Grows the rectangle, if needed, so that it contains the given point. */
    public void enlargeTo(INDArray point) {
        float[] pointAsArray = point.toFloatVector();
        for (int i = 0; i < lowerEnds.length; i++) {
            float p = pointAsArray[i];
            if (lowerEnds[i] > p)
                lowerEnds[i] = p;
            else if (higherEnds[i] < p)
                higherEnds[i] = p;
        }
    }

    /** Builds a degenerate (zero-volume) pair of ends whose lower and upper bounds are both the given vector. */
    public static Pair<float[], float[]> point(INDArray vector) {
        Pair<float[], float[]> ret = new Pair<>();
        float[] curr = new float[(int) vector.length()];
        for (int i = 0; i < vector.length(); i++) {
            curr[i] = vector.getFloat(i);
        }
        ret.setFirst(curr);
        ret.setSecond(curr);
        return ret;
    }

    /** Minimum distance from the given point to this rectangle, computed by the KnnMinDistance op. */
    public double minDistance(INDArray hPoint, INDArray output) {
        Nd4j.exec(new KnnMinDistance(hPoint, lowerEndsIND, higherEndsIND, output));
        return output.getFloat(0);
    }

    /** The half of this rectangle above the given point's coordinate on dimension desc, or null if empty. */
    public HyperRect getUpper(INDArray hPoint, int desc) {
        float higher = higherEnds[desc];
        float d = hPoint.getFloat(desc);
        if (higher < d)
            return null;
        HyperRect ret = new HyperRect(lowerEnds, higherEnds);
        if (ret.lowerEnds[desc] < d)
            ret.lowerEnds[desc] = d;
        return ret;
    }

    /** The half of this rectangle below the given point's coordinate on dimension desc, or null if empty. */
    public HyperRect getLower(INDArray hPoint, int desc) {
        float lower = lowerEnds[desc];
        float d = hPoint.getFloat(desc);
        if (lower > d)
            return null;
        HyperRect ret = new HyperRect(lowerEnds, higherEnds);
        if (ret.higherEnds[desc] > d)
            ret.higherEnds[desc] = d;
        return ret;
    }

    @Override
    public String toString() {
        StringBuilder retVal = new StringBuilder("[");
        for (int i = 0; i < lowerEnds.length; ++i) {
            retVal.append("(").append(lowerEnds[i]).append(" - ").append(higherEnds[i]).append(") ");
        }
        retVal.append("]");
        return retVal.toString();
    }
}
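A small sketch of how the splitting behaves, with toy 2-D values (not from the diff; the imports of the file above are assumed):

    INDArray out = Nd4j.scalar(0.f);
    HyperRect rect = new HyperRect(new float[] {0f, 0f}, new float[] {4f, 4f});
    INDArray p = Nd4j.createFromArray(1f, 3f);
    HyperRect upper = rect.getUpper(p, 0);   // x in [1, 4], y in [0, 4]
    HyperRect lower = rect.getLower(p, 0);   // x in [0, 1], y in [0, 4]
    double d = rect.minDistance(p, out);     // the point lies inside, so the minimum distance should be 0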
@ -1,370 +0,0 @@
/* Apache License, Version 2.0. SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.kdtree;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.bool.Any;
import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;

public class KDTree implements Serializable {

    private KDNode root;
    private int dims = 100;
    public final static int GREATER = 1;
    public final static int LESS = 0;
    private int size = 0;
    private HyperRect rect;

    public KDTree(int dims) {
        this.dims = dims;
    }

    /**
     * Insert a point into the tree
     * @param point the point to insert
     */
    public void insert(INDArray point) {
        if (!point.isVector() || point.length() != dims)
            throw new IllegalArgumentException("Point must be a vector of length " + dims);

        if (root == null) {
            root = new KDNode(point);
            rect = new HyperRect(point.toFloatVector());
        } else {
            int disc = 0;
            KDNode node = root;
            KDNode insert = new KDNode(point);
            int successor;
            while (true) {
                // A point exactly equal to an existing one is not inserted twice
                INDArray pt = node.getPoint();
                INDArray countEq = Nd4j.getExecutioner().execAndReturn(new Any(pt.neq(point))).z();
                if (countEq.getInt(0) == 0) {
                    return;
                } else {
                    successor = successor(node, point, disc);
                    KDNode child;
                    if (successor < 1)
                        child = node.getLeft();
                    else
                        child = node.getRight();
                    if (child == null)
                        break;
                    disc = (disc + 1) % dims;
                    node = child;
                }
            }

            if (successor < 1)
                node.setLeft(insert);
            else
                node.setRight(insert);

            rect.enlargeTo(point);
            insert.setParent(node);
        }
        size++;
    }

    public INDArray delete(INDArray point) {
        KDNode node = root;
        int _disc = 0;
        while (node != null) {
            // Reference equality: delete expects the exact stored point instance
            if (node.point == point)
                break;
            int successor = successor(node, point, _disc);
            if (successor < 1)
                node = node.getLeft();
            else
                node = node.getRight();
            _disc = (_disc + 1) % dims;
        }

        if (node == null)
            return null;

        // Capture the point before the node is restructured away
        INDArray deletedPoint = node.getPoint();
        if (node == root) {
            root = delete(root, _disc);
        } else {
            delete(node, _disc);
        }
        size--;
        if (size == 1) {
            rect = new HyperRect(HyperRect.point(deletedPoint));
        } else if (size == 0)
            rect = null;

        return deletedPoint;
    }

    // Shared state for recursive calls of knn
    private float currentDistance;
    private INDArray currentPoint;
    private INDArray minDistance = Nd4j.scalar(0.f);

    public List<Pair<Float, INDArray>> knn(INDArray point, float distance) {
        List<Pair<Float, INDArray>> best = new ArrayList<>();
        currentDistance = distance;
        currentPoint = point;
        knn(root, rect, best, 0);
        Collections.sort(best, new Comparator<Pair<Float, INDArray>>() {
            @Override
            public int compare(Pair<Float, INDArray> o1, Pair<Float, INDArray> o2) {
                return Float.compare(o1.getKey(), o2.getKey());
            }
        });
        return best;
    }

    private void knn(KDNode node, HyperRect rect, List<Pair<Float, INDArray>> best, int _disc) {
        // Prune subtrees whose bounding rectangle cannot contain a point within range
        if (node == null || rect == null || rect.minDistance(currentPoint, minDistance) > currentDistance)
            return;
        int _discNext = (_disc + 1) % dims;
        float distance = Nd4j.getExecutioner()
                        .execAndReturn(new EuclideanDistance(currentPoint, node.point, minDistance))
                        .getFinalResult().floatValue();

        if (distance <= currentDistance) {
            best.add(Pair.of(distance, node.getPoint()));
        }

        HyperRect lower = rect.getLower(node.point, _disc);
        HyperRect upper = rect.getUpper(node.point, _disc);
        knn(node.getLeft(), lower, best, _discNext);
        knn(node.getRight(), upper, best, _discNext);
    }

    /**
     * Query for the nearest neighbor. Returns the distance and point
     * @param point the point to query for
     * @return a pair of (distance, nearest point)
     */
    public Pair<Double, INDArray> nn(INDArray point) {
        return nn(root, point, rect, Double.POSITIVE_INFINITY, null, 0);
    }

    private Pair<Double, INDArray> nn(KDNode node, INDArray point, HyperRect rect, double dist, INDArray best,
                    int _disc) {
        if (node == null || rect.minDistance(point, minDistance) > dist)
            return Pair.of(Double.POSITIVE_INFINITY, null);

        int _discNext = (_disc + 1) % dims;
        // Distance from the query to this node's point
        double dist2 = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(point, node.point))
                        .getFinalResult().doubleValue();
        if (dist2 < dist) {
            best = node.getPoint();
            dist = dist2;
        }

        HyperRect lower = rect.getLower(node.point, _disc);
        HyperRect upper = rect.getUpper(node.point, _disc);

        // Descend the side of the splitting plane containing the query first
        if (point.getDouble(_disc) < node.point.getDouble(_disc)) {
            Pair<Double, INDArray> left = nn(node.getLeft(), point, lower, dist, best, _discNext);
            Pair<Double, INDArray> right = nn(node.getRight(), point, upper, dist, best, _discNext);
            if (left.getKey() < dist)
                return left;
            else if (right.getKey() < dist)
                return right;
        } else {
            Pair<Double, INDArray> left = nn(node.getRight(), point, upper, dist, best, _discNext);
            Pair<Double, INDArray> right = nn(node.getLeft(), point, lower, dist, best, _discNext);
            if (left.getKey() < dist)
                return left;
            else if (right.getKey() < dist)
                return right;
        }

        return Pair.of(dist, best);
    }

    private KDNode delete(KDNode delete, int _disc) {
        if (delete.getLeft() != null && delete.getRight() != null) {
            if (delete.getParent() != null) {
                if (delete.getParent().getLeft() == delete)
                    delete.getParent().setLeft(null);
                else
                    delete.getParent().setRight(null);
            }
            return null;
        }

        int disc = _disc;
        _disc = (_disc + 1) % dims;
        Pair<KDNode, Integer> qd = null;
        if (delete.getRight() != null) {
            qd = min(delete.getRight(), disc, _disc);
        } else if (delete.getLeft() != null)
            qd = max(delete.getLeft(), disc, _disc);
        if (qd == null) { // is a leaf
            return null;
        }
        delete.point = qd.getKey().point;
        KDNode qFather = qd.getKey().getParent();
        if (qFather.getLeft() == qd.getKey()) {
            qFather.setLeft(delete(qd.getKey(), disc));
        } else if (qFather.getRight() == qd.getKey()) {
            qFather.setRight(delete(qd.getKey(), disc));
        }

        return delete;
    }

    private Pair<KDNode, Integer> max(KDNode node, int disc, int _disc) {
        int discNext = (_disc + 1) % dims;
        if (_disc == disc) {
            // Along the splitting dimension, the maximum can only lie in the right subtree
            KDNode child = node.getRight();
            if (child != null) {
                return max(child, disc, discNext);
            }
        } else if (node.getLeft() != null || node.getRight() != null) {
            Pair<KDNode, Integer> left = null, right = null;
            if (node.getLeft() != null)
                left = max(node.getLeft(), disc, discNext);
            if (node.getRight() != null)
                right = max(node.getRight(), disc, discNext);
            if (left != null && right != null) {
                double pointLeft = left.getKey().getPoint().getDouble(disc);
                double pointRight = right.getKey().getPoint().getDouble(disc);
                if (pointLeft > pointRight)
                    return left;
                else
                    return right;
            } else if (left != null)
                return left;
            else
                return right;
        }

        return Pair.of(node, _disc);
    }

    private Pair<KDNode, Integer> min(KDNode node, int disc, int _disc) {
        int discNext = (_disc + 1) % dims;
        if (_disc == disc) {
            // Along the splitting dimension, the minimum can only lie in the left subtree
            KDNode child = node.getLeft();
            if (child != null) {
                return min(child, disc, discNext);
            }
        } else if (node.getLeft() != null || node.getRight() != null) {
            Pair<KDNode, Integer> left = null, right = null;
            if (node.getLeft() != null)
                left = min(node.getLeft(), disc, discNext);
            if (node.getRight() != null)
                right = min(node.getRight(), disc, discNext);
            if (left != null && right != null) {
                double pointLeft = left.getKey().getPoint().getDouble(disc);
                double pointRight = right.getKey().getPoint().getDouble(disc);
                if (pointLeft < pointRight)
                    return left;
                else
                    return right;
            } else if (left != null)
                return left;
            else
                return right;
        }

        return Pair.of(node, _disc);
    }

    /**
     * The number of elements in the tree
     * @return the number of elements in the tree
     */
    public int size() {
        return size;
    }

    private int successor(KDNode node, INDArray point, int disc) {
        for (int i = disc; i < dims; i++) {
            double pointI = point.getDouble(i);
            double nodePointI = node.getPoint().getDouble(i);
            if (pointI < nodePointI)
                return LESS;
            else if (pointI > nodePointI)
                return GREATER;
        }

        throw new IllegalStateException("Point is equal!");
    }

    private static class KDNode {
        private INDArray point;
        private KDNode left, right, parent;

        public KDNode(INDArray point) {
            this.point = point;
        }

        public INDArray getPoint() {
            return point;
        }

        public KDNode getLeft() {
            return left;
        }

        public void setLeft(KDNode left) {
            this.left = left;
        }

        public KDNode getRight() {
            return right;
        }

        public void setRight(KDNode right) {
            this.right = right;
        }

        public KDNode getParent() {
            return parent;
        }

        public void setParent(KDNode parent) {
            this.parent = parent;
        }
    }
}
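A minimal usage sketch of the removed tree (toy data; only the insert/nn/knn API shown above is used, together with the file's own imports):

    KDTree tree = new KDTree(2);                       // 2-dimensional points
    tree.insert(Nd4j.createFromArray(0f, 0f));
    tree.insert(Nd4j.createFromArray(3f, 4f));
    tree.insert(Nd4j.createFromArray(1f, 1f));

    Pair<Double, INDArray> nearest = tree.nn(Nd4j.createFromArray(0.9f, 0.9f));
    // nearest.getKey()   -> distance to (1, 1), about 0.14
    // nearest.getValue() -> the point (1, 1)

    List<Pair<Float, INDArray>> withinTwo = tree.knn(Nd4j.createFromArray(0f, 0f), 2.0f);
    // all points within Euclidean distance 2 of the origin, sorted by distance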
@ -1,109 +0,0 @@
/* Apache License, Version 2.0. SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.kmeans;

import org.deeplearning4j.clustering.algorithm.BaseClusteringAlgorithm;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.strategy.ClusteringStrategy;
import org.deeplearning4j.clustering.strategy.FixedClusterCountStrategy;

public class KMeansClustering extends BaseClusteringAlgorithm {

    private static final long serialVersionUID = 8476951388145944776L;
    private static final double VARIATION_TOLERANCE = 1e-4;

    /**
     * @param clusteringStrategy the strategy that controls the cluster count and termination
     * @param useKMeansPlusPlus whether to use k-means++ initialization
     */
    protected KMeansClustering(ClusteringStrategy clusteringStrategy, boolean useKMeansPlusPlus) {
        super(clusteringStrategy, useKMeansPlusPlus);
    }

    /**
     * Set up a k-means instance that stops after a fixed number of iterations
     * @param clusterCount the number of clusters
     * @param maxIterationCount the max number of iterations to run k-means
     * @param distanceFunction the distance function to use for grouping
     * @param inverse whether the distance function is a similarity (higher means closer)
     * @return a configured KMeansClustering instance
     */
    public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction,
                    boolean inverse, boolean useKMeansPlusPlus) {
        ClusteringStrategy clusteringStrategy =
                        FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse);
        clusteringStrategy.endWhenIterationCountEquals(maxIterationCount);
        return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
    }

    /**
     * Set up a k-means instance that stops once the cluster distribution stabilizes
     * @param clusterCount the number of clusters
     * @param minDistributionVariationRate the variation rate below which iteration stops
     * @param distanceFunction the distance function to use for grouping
     * @param inverse whether the distance function is a similarity (higher means closer)
     * @return a configured KMeansClustering instance
     */
    public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction,
                    boolean inverse, boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
        ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse)
                        .endWhenDistributionVariationRateLessThan(minDistributionVariationRate);
        return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
    }

    /**
     * Set up a k-means instance that stops after a fixed number of iterations
     * @param clusterCount the number of clusters
     * @param maxIterationCount the max number of iterations to run k-means
     * @param distanceFunction the distance function to use for grouping
     * @return a configured KMeansClustering instance
     */
    public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction, boolean useKMeansPlusPlus) {
        return setup(clusterCount, maxIterationCount, distanceFunction, false, useKMeansPlusPlus);
    }

    /**
     * Set up a k-means instance that stops once the cluster distribution stabilizes
     * @param clusterCount the number of clusters
     * @param minDistributionVariationRate the variation rate below which iteration stops
     * @param distanceFunction the distance function to use for grouping
     * @param allowEmptyClusters whether clusters may end up empty
     * @return a configured KMeansClustering instance
     */
    public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction,
                    boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
        ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false);
        clusteringStrategy.endWhenDistributionVariationRateLessThan(minDistributionVariationRate);
        return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
    }

    public static KMeansClustering setup(int clusterCount, Distance distanceFunction,
                    boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
        ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false);
        clusteringStrategy.endWhenDistributionVariationRateLessThan(VARIATION_TOLERANCE);
        return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
    }
}
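For context, a sketch of how these factories were typically invoked. Only the setup signatures are confirmed by this diff; the Distance constant name and the applyTo call are assumptions about the surrounding module.

    // Assumed: Distance.EUCLIDEAN exists on the Distance enum imported above,
    // and BaseClusteringAlgorithm exposes an applyTo(...) entry point.
    KMeansClustering kmeans = KMeansClustering.setup(
                    5,                   // clusterCount
                    100,                 // maxIterationCount
                    Distance.EUCLIDEAN,  // distanceFunction (assumed constant name)
                    true);               // useKMeansPlusPlus
    // ClusterSet clusters = kmeans.applyTo(points);   // hypothetical call, see note above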
@ -1,88 +0,0 @@
/* Apache License, Version 2.0. SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.lsh;

import org.nd4j.linalg.api.ndarray.INDArray;

public interface LSH {

    /**
     * Returns the name of the distance measure associated with the LSH family of this implementation.
     * Beware, hashing families and their amplification constructs are distance-specific.
     */
    String getDistanceMeasure();

    /**
     * Returns the size of a hash compared against in one hashing bucket, corresponding to an AND construction:
     *
     * denoting hashLength by h, the AND construction
     * amplifies a (d1, d2, p1, p2)-sensitive hash family into a
     * (d1, d2, p1^h, p2^h)-sensitive one (the match probability decreases with h)
     *
     * @return the length of the hash in the AND construction used by this index
     */
    int getHashLength();

    /**
     * denoting numTables by n, the OR construction
     * amplifies a (d1, d2, p1, p2)-sensitive hash family into a
     * (d1, d2, 1-(1-p1)^n, 1-(1-p2)^n)-sensitive one (the match probability increases with n)
     *
     * @return the number of hash tables in the OR construction used by this index
     */
    int getNumTables();

    /**
     * @return the dimension of the index vectors and queries
     */
    int getInDimension();

    /**
     * Populates the index with data vectors.
     * @param data the vectors to index
     */
    void makeIndex(INDArray data);

    /**
     * Returns the set of all vectors that could approximately be considered neighbors of the query,
     * without selection on the basis of distance or number of neighbors.
     * @param query a vector to find neighbors for
     * @return its approximate neighbors, unfiltered
     */
    INDArray bucket(INDArray query);

    /**
     * Returns the approximate neighbors within a distance bound.
     * @param query a vector to find neighbors for
     * @param maxRange the maximum distance between results and the query
     * @return approximate neighbors within the distance bound
     */
    INDArray search(INDArray query, double maxRange);

    /**
     * Returns the approximate neighbors within a k-closest bound
     * @param query a vector to find neighbors for
     * @param k the maximum number of closest neighbors to return
     * @return at most k neighbors of the query, ordered by increasing distance
     */
    INDArray search(INDArray query, int k);
}
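To make the AND/OR amplification concrete, a small worked sketch with illustrative probabilities (the numbers are not from the diff):

    // Illustrative only: how AND/OR amplification reshapes collision probabilities.
    double p1 = 0.8, p2 = 0.4;   // per-hash collision probabilities at distances d1 < d2
    int h = 4, n = 8;            // hashLength and numTables

    double andP1 = Math.pow(p1, h);              // 0.41: AND sharpens, both probabilities drop
    double andP2 = Math.pow(p2, h);              // 0.03
    double orP1 = 1 - Math.pow(1 - andP1, n);    // about 0.99: OR over n tables recovers recall
    double orP2 = 1 - Math.pow(1 - andP2, n);    // about 0.19: far points stay unlikely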
@ -1,227 +0,0 @@
/* Apache License, Version 2.0. SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.lsh;

import lombok.Getter;
import lombok.val;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo;
import org.nd4j.linalg.api.ops.impl.transforms.same.Sign;
import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
import org.nd4j.linalg.api.rng.Random;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
import org.nd4j.linalg.ops.transforms.Transforms;

import java.util.Arrays;

public class RandomProjectionLSH implements LSH {

    @Override
    public String getDistanceMeasure() {
        return "cosinedistance";
    }

    @Getter private int hashLength;

    @Getter private int numTables;

    @Getter private int inDimension;

    @Getter private double radius;

    INDArray randomProjection;

    INDArray index;

    INDArray indexData;

    private INDArray gaussianRandomMatrix(int[] shape, Random rng) {
        INDArray res = Nd4j.create(shape);
        GaussianDistribution op1 = new GaussianDistribution(res, 0.0, 1.0 / Math.sqrt(shape[0]));
        Nd4j.getExecutioner().exec(op1, rng);
        return res;
    }

    public RandomProjectionLSH(int hashLength, int numTables, int inDimension, double radius) {
        this(hashLength, numTables, inDimension, radius, Nd4j.getRandom());
    }

    /**
     * Creates a locality-sensitive hashing index for the cosine distance,
     * a (d1, d2, (180 - d1)/180, (180 - d2)/180)-sensitive hash family before amplification
     *
     * @param hashLength the length of the compared hash in an AND construction
     * @param numTables the entropy-equivalent of a number of hash tables in an OR construction, implemented here with
     *                  the multiple probes of Panigrahy (op. cit.); instead of using multiple physical hash tables,
     *                  probes are drawn around the query
     * @param inDimension the dimensionality of the points being indexed
     * @param radius the radius around the query within which to generate probes
     * @param rng a Random object to draw samples from
     */
    public RandomProjectionLSH(int hashLength, int numTables, int inDimension, double radius, Random rng) {
        this.hashLength = hashLength;
        this.numTables = numTables;
        this.inDimension = inDimension;
        this.radius = radius;
        randomProjection = gaussianRandomMatrix(new int[] {inDimension, hashLength}, rng);
    }

    /**
     * This picks uniformly distributed random points on the surface of a sphere around the query, using the method of:
     *
     * An efficient method for generating uniformly distributed points on the surface of an n-dimensional sphere
     * JS Hicks, RF Wheeling - Communications of the ACM, 1959
     * @param data a query to generate multiple probes for
     * @return numTables probe points drawn around the query
     */
    public INDArray entropy(INDArray data) {
        INDArray data2 =
                        Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.create(numTables, inDimension), radius));

        INDArray norms = Nd4j.norm2(data2.dup(), -1);

        Preconditions.checkState(norms.rank() == 1 && norms.size(0) == numTables, "Expected norm2 to have shape [%s], is %ndShape", norms.size(0), norms);

        data2.diviColumnVector(norms);
        data2.addiRowVector(data);
        return data2;
    }

    /**
     * Returns hash values for a particular query
     * @param data a query vector
     * @return its hashed value
     */
    public INDArray hash(INDArray data) {
        if (data.shape()[1] != inDimension) {
            throw new ND4JIllegalStateException(
                            String.format("Invalid shape: Requested INDArray shape %s, this table expects dimension %d",
                                            Arrays.toString(data.shape()), inDimension));
        }
        INDArray projected = data.mmul(randomProjection);
        INDArray res = Nd4j.getExecutioner().exec(new Sign(projected));
        return res;
    }

    /**
     * Populates the index. Beware, not incremental: any further call replaces the index instead of adding to it.
     * @param data the vectors to index
     */
    @Override
    public void makeIndex(INDArray data) {
        index = hash(data);
        indexData = data;
    }

    // data elements in the same bucket as the query, without entropy
    INDArray rawBucketOf(INDArray query) {
        INDArray pattern = hash(query);

        INDArray res = Nd4j.zeros(DataType.BOOL, index.shape());
        Nd4j.getExecutioner().exec(new BroadcastEqualTo(index, pattern, res, -1));
        return res.castTo(Nd4j.defaultFloatingPointType()).min(-1);
    }

    @Override
    public INDArray bucket(INDArray query) {
        INDArray queryRes = rawBucketOf(query);

        if (numTables > 1) {
            INDArray entropyQueries = entropy(query);

            // loop, addi + conditional replace -> poor man's OR function
            for (int i = 0; i < numTables; i++) {
                INDArray row = entropyQueries.getRow(i, true);
                queryRes.addi(rawBucketOf(row));
            }
            BooleanIndexing.replaceWhere(queryRes, 1.0, Conditions.greaterThan(0.0));
        }

        return queryRes;
    }

    // data elements in the same entropy bucket as the query
    INDArray bucketData(INDArray query) {
        INDArray mask = bucket(query);
        int nRes = mask.sum(0).getInt(0);
        INDArray res = Nd4j.create(new int[] {nRes, inDimension});
        int j = 0;
        for (int i = 0; i < nRes; i++) {
            while (mask.getInt(j) == 0 && j < mask.length() - 1) {
                j += 1;
            }
            if (mask.getInt(j) == 1)
                res.putRow(i, indexData.getRow(j));
            j += 1;
        }
        return res;
    }

    @Override
    public INDArray search(INDArray query, double maxRange) {
        if (maxRange < 0)
            throw new IllegalArgumentException("ANN search should have a positive maximum search radius");

        INDArray bucketData = bucketData(query);
        INDArray distances = Transforms.allCosineDistances(bucketData, query, -1);
        INDArray[] idxs = Nd4j.sortWithIndices(distances, -1, true);

        INDArray shuffleIndexes = idxs[0];
        INDArray sortedDistances = idxs[1];
        int accepted = 0;
        // cosine distances are fractional, so read them as doubles rather than truncating to int
        while (accepted < sortedDistances.length() && sortedDistances.getDouble(accepted) <= maxRange)
            accepted += 1;

        INDArray res = Nd4j.create(new int[] {accepted, inDimension});
        for (int i = 0; i < accepted; i++) {
            res.putRow(i, bucketData.getRow(shuffleIndexes.getInt(i)));
        }
        return res;
    }

    @Override
    public INDArray search(INDArray query, int k) {
        if (k < 1)
            throw new IllegalArgumentException("An ANN search for k neighbors should at least seek one neighbor");

        INDArray bucketData = bucketData(query);
        INDArray distances = Transforms.allCosineDistances(bucketData, query, -1);
        INDArray[] idxs = Nd4j.sortWithIndices(distances, -1, true);

        INDArray shuffleIndexes = idxs[0];
        INDArray sortedDistances = idxs[1];
        val accepted = Math.min(k, sortedDistances.shape()[1]);

        INDArray res = Nd4j.create(accepted, inDimension);
        for (int i = 0; i < accepted; i++) {
            res.putRow(i, bucketData.getRow(shuffleIndexes.getInt(i)));
        }
        return res;
    }
}
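A usage sketch with toy data (shapes chosen to satisfy the inDimension check in hash(); the file's own imports are assumed):

    RandomProjectionLSH lsh = new RandomProjectionLSH(16, 4, 32, 0.1);  // hashLength, numTables, inDimension, radius
    INDArray data = Nd4j.rand(100, 32);
    lsh.makeIndex(data);

    INDArray query = data.getRow(0, true);       // keep the row 2-d: hash() reads shape()[1]
    INDArray nearest = lsh.search(query, 5);     // at most 5 approximate neighbors by cosine distance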
@ -1,38 +0,0 @@
/* Apache License, Version 2.0. SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.optimisation;

import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.io.Serializable;

@Data
@NoArgsConstructor(access = AccessLevel.PROTECTED)
@AllArgsConstructor
public class ClusteringOptimization implements Serializable {

    private ClusteringOptimizationType type;
    private double value;
}
@ -1,28 +0,0 @@
/* Apache License, Version 2.0. SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.optimisation;

/**
 * The objective a clustering optimization minimizes.
 */
public enum ClusteringOptimizationType {
    MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE,
    MINIMIZE_MAXIMUM_POINT_TO_CENTER_DISTANCE,
    MINIMIZE_AVERAGE_POINT_TO_POINT_DISTANCE,
    MINIMIZE_MAXIMUM_POINT_TO_POINT_DISTANCE,
    MINIMIZE_PER_CLUSTER_POINT_COUNT
}
@ -1,115 +0,0 @@
/* Apache License, Version 2.0. SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.quadtree;

import org.nd4j.linalg.api.ndarray.INDArray;

import java.io.Serializable;

/**
 * An axis-aligned rectangle given by its center (x, y) and half-width/half-height (hw, hh).
 */
public class Cell implements Serializable {
    private double x, y, hw, hh;

    public Cell(double x, double y, double hw, double hh) {
        this.x = x;
        this.y = y;
        this.hw = hw;
        this.hh = hh;
    }

    /**
     * Whether the given point is contained
     * within this cell
     * @param point the point to check
     * @return true if the point is contained, false otherwise
     */
    public boolean containsPoint(INDArray point) {
        double first = point.getDouble(0), second = point.getDouble(1);
        return x - hw <= first && x + hw >= first && y - hh <= second && y + hh >= second;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o)
            return true;
        if (!(o instanceof Cell))
            return false;

        Cell cell = (Cell) o;

        if (Double.compare(cell.hh, hh) != 0)
            return false;
        if (Double.compare(cell.hw, hw) != 0)
            return false;
        if (Double.compare(cell.x, x) != 0)
            return false;
        return Double.compare(cell.y, y) == 0;
    }

    @Override
    public int hashCode() {
        int result;
        long temp;
        temp = Double.doubleToLongBits(x);
        result = (int) (temp ^ (temp >>> 32));
        temp = Double.doubleToLongBits(y);
        result = 31 * result + (int) (temp ^ (temp >>> 32));
        temp = Double.doubleToLongBits(hw);
        result = 31 * result + (int) (temp ^ (temp >>> 32));
        temp = Double.doubleToLongBits(hh);
        result = 31 * result + (int) (temp ^ (temp >>> 32));
        return result;
    }

    public double getX() {
        return x;
    }

    public void setX(double x) {
        this.x = x;
    }

    public double getY() {
        return y;
    }

    public void setY(double y) {
        this.y = y;
    }

    public double getHw() {
        return hw;
    }

    public void setHw(double hw) {
        this.hw = hw;
    }

    public double getHh() {
        return hh;
    }

    public void setHh(double hh) {
        this.hh = hh;
    }
}
@ -1,383 +0,0 @@
/* Apache License, Version 2.0. SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.quadtree;

import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import org.apache.commons.math3.util.FastMath;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.io.Serializable;

import static java.lang.Math.max;

public class QuadTree implements Serializable {
    private QuadTree parent, northWest, northEast, southWest, southEast;
    private boolean isLeaf = true;
    private int size, cumSize;
    private Cell boundary;
    static final int QT_NO_DIMS = 2;
    static final int QT_NODE_CAPACITY = 1;
    private INDArray buf = Nd4j.create(QT_NO_DIMS);
    private INDArray data, centerOfMass = Nd4j.create(QT_NO_DIMS);
    private int[] index = new int[QT_NODE_CAPACITY];

    /**
     * Build a quad tree over a matrix whose rows are 2-d points
     * @param data the matrix of points to index
     */
    public QuadTree(INDArray data) {
        INDArray meanY = data.mean(0);
        INDArray minY = data.min(0);
        INDArray maxY = data.max(0);
        init(data, meanY.getDouble(0), meanY.getDouble(1),
                        max(maxY.getDouble(0) - meanY.getDouble(0), meanY.getDouble(0) - minY.getDouble(0))
                                        + Nd4j.EPS_THRESHOLD,
                        max(maxY.getDouble(1) - meanY.getDouble(1), meanY.getDouble(1) - minY.getDouble(1))
                                        + Nd4j.EPS_THRESHOLD);
        fill();
    }

    public QuadTree(QuadTree parent, INDArray data, Cell boundary) {
        this.parent = parent;
        this.boundary = boundary;
        this.data = data;
    }

    public QuadTree(Cell boundary) {
        this.boundary = boundary;
    }

    private void init(INDArray data, double x, double y, double hw, double hh) {
        boundary = new Cell(x, y, hw, hh);
        this.data = data;
    }

    private void fill() {
        for (int i = 0; i < data.rows(); i++)
            insert(i);
    }

    /**
     * Returns the child quadrant the given coordinates fall into
     *
     * @param coordinates the 2-d point to locate
     * @return the child quad tree containing the coordinates
     */
    protected QuadTree findIndex(INDArray coordinates) {
        // Compute the sector for the coordinates
        boolean left = (coordinates.getDouble(0) <= (boundary.getX() + boundary.getHw() / 2));
        boolean top = (coordinates.getDouble(1) <= (boundary.getY() + boundary.getHh() / 2));

        // top left
        QuadTree index = getNorthWest();
        if (left) {
            // left side
            if (!top) {
                // bottom left
                index = getSouthWest();
            }
        } else {
            // right side
            if (top) {
                // top right
                index = getNorthEast();
            } else {
                // bottom right
                index = getSouthEast();
            }
        }

        return index;
    }

    /**
     * Insert an index of the data into the tree
     * @param newIndex the index to insert into the tree
     * @return whether the index was inserted or not
     */
    public boolean insert(int newIndex) {
        // Ignore objects which do not belong in this quad tree
        INDArray point = data.slice(newIndex);
        if (!boundary.containsPoint(point))
            return false;

        cumSize++;
        double mult1 = (double) (cumSize - 1) / (double) cumSize;
        double mult2 = 1.0 / (double) cumSize;

        // Running mean: the center of mass is updated incrementally
        centerOfMass.muli(mult1);
        centerOfMass.addi(point.mul(mult2));

        // If there is space in this quad tree and it is a leaf, add the object here
        if (isLeaf() && size < QT_NODE_CAPACITY) {
            index[size] = newIndex;
            size++;
            return true;
        }

        // duplicate point
        if (size > 0) {
            for (int i = 0; i < size; i++) {
                INDArray compPoint = data.slice(index[i]);
                if (point.getDouble(0) == compPoint.getDouble(0) && point.getDouble(1) == compPoint.getDouble(1))
                    return true;
            }
        }

        // If this node has already been subdivided, just add the element to the
        // appropriate cell
        if (!isLeaf()) {
            QuadTree index = findIndex(point);
            index.insert(newIndex);
            return true;
        }

        if (isLeaf())
            subDivide();

        return insertIntoOneOf(newIndex);
    }

    private boolean insertIntoOneOf(int index) {
        boolean success = false;
        success = northWest.insert(index);
        if (!success)
            success = northEast.insert(index);
        if (!success)
            success = southWest.insert(index);
        if (!success)
            success = southEast.insert(index);
        return success;
    }

    /**
     * Returns whether the tree is consistent or not
     * @return whether the tree is consistent or not
     */
    public boolean isCorrect() {
        for (int n = 0; n < size; n++) {
            INDArray point = data.slice(index[n]);
            if (!boundary.containsPoint(point))
                return false;
        }

        return isLeaf() || northWest.isCorrect() && northEast.isCorrect() && southWest.isCorrect()
                        && southEast.isCorrect();
    }

    /**
     * Create four children
     * which fully divide this cell
     * into four quads of equal area
     */
    public void subDivide() {
        northWest = new QuadTree(this, data, new Cell(boundary.getX() - .5 * boundary.getHw(),
                        boundary.getY() - .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh()));
        northEast = new QuadTree(this, data, new Cell(boundary.getX() + .5 * boundary.getHw(),
                        boundary.getY() - .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh()));
        southWest = new QuadTree(this, data, new Cell(boundary.getX() - .5 * boundary.getHw(),
                        boundary.getY() + .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh()));
        southEast = new QuadTree(this, data, new Cell(boundary.getX() + .5 * boundary.getHw(),
                        boundary.getY() + .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh()));
    }

    /**
     * Compute non-edge forces using Barnes-Hut
     * @param pointIndex the index of the point the forces act on
     * @param theta the Barnes-Hut accuracy/speed trade-off parameter
     * @param negativeForce accumulator for the negative force on the point
     * @param sumQ accumulator for the normalization term
     */
    public void computeNonEdgeForces(int pointIndex, double theta, INDArray negativeForce, AtomicDouble sumQ) {
        // Make sure that we spend no time on empty nodes or self-interactions
        if (cumSize == 0 || (isLeaf() && size == 1 && index[0] == pointIndex))
            return;

        // Compute distance between point and center-of-mass
        buf.assign(data.slice(pointIndex)).subi(centerOfMass);

        double D = Nd4j.getBlasWrapper().dot(buf, buf);

        // Check whether we can use this node as a "summary"
        if (isLeaf || FastMath.max(boundary.getHh(), boundary.getHw()) / FastMath.sqrt(D) < theta) {
            // Compute and add t-SNE force between point and current node
            double Q = 1.0 / (1.0 + D);
            double mult = cumSize * Q;
            sumQ.addAndGet(mult);
            mult *= Q;
            negativeForce.addi(buf.mul(mult));
        } else {
            // Recursively apply Barnes-Hut to children
            northWest.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
            northEast.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
            southWest.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
            southEast.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
        }
    }

    /**
     * Compute edge forces over a sparse (CSR-style) affinity matrix
     * @param rowP row pointers: rowP.getInt(n) is the offset of row n's entries in colP and valP
     * @param colP column indices of the non-zero affinities
     * @param valP values of the non-zero affinities
     * @param N the number of points
     * @param posF accumulator for the positive forces, one row per point
     */
    public void computeEdgeForces(INDArray rowP, INDArray colP, INDArray valP, int N, INDArray posF) {
        if (!rowP.isVector())
            throw new IllegalArgumentException("RowP must be a vector");

        // Loop over all edges in the graph
        double D;
        for (int n = 0; n < N; n++) {
            for (int i = rowP.getInt(n); i < rowP.getInt(n + 1); i++) {
                // Compute pairwise distance and Q-value
                buf.assign(data.slice(n)).subi(data.slice(colP.getInt(i)));

                D = Nd4j.getBlasWrapper().dot(buf, buf);
                D = valP.getDouble(i) / D;

                // Sum positive force
                posF.slice(n).addi(buf.mul(D));
            }
        }
    }

    /**
     * The depth of the node
     * @return the depth of the node
     */
    public int depth() {
        if (isLeaf())
            return 1;
        return 1 + max(max(northWest.depth(), northEast.depth()), max(southWest.depth(), southEast.depth()));
    }

    public INDArray getCenterOfMass() {
        return centerOfMass;
    }

    public void setCenterOfMass(INDArray centerOfMass) {
        this.centerOfMass = centerOfMass;
    }

    public QuadTree getParent() {
        return parent;
    }

    public void setParent(QuadTree parent) {
        this.parent = parent;
    }

    public QuadTree getNorthWest() {
        return northWest;
    }

    public void setNorthWest(QuadTree northWest) {
        this.northWest = northWest;
    }

    public QuadTree getNorthEast() {
        return northEast;
    }

    public void setNorthEast(QuadTree northEast) {
        this.northEast = northEast;
    }

    public QuadTree getSouthWest() {
        return southWest;
    }

    public void setSouthWest(QuadTree southWest) {
        this.southWest = southWest;
    }

    public QuadTree getSouthEast() {
        return southEast;
    }

    public void setSouthEast(QuadTree southEast) {
        this.southEast = southEast;
    }

    public boolean isLeaf() {
        return isLeaf;
    }

    public void setLeaf(boolean isLeaf) {
        this.isLeaf = isLeaf;
    }

    public int getSize() {
        return size;
    }

    public void setSize(int size) {
        this.size = size;
    }

    public int getCumSize() {
        return cumSize;
    }

    public void setCumSize(int cumSize) {
        this.cumSize = cumSize;
    }

    public Cell getBoundary() {
        return boundary;
    }

    public void setBoundary(Cell boundary) {
        this.boundary = boundary;
    }
}
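A sketch of the Barnes-Hut entry points with toy data (theta = 0.5 is a conventional accuracy/speed trade-off, not a value from the diff; the file's own imports are assumed):

    // Build the tree over 2-d points and accumulate the negative (repulsive)
    // force for point 0, as the t-SNE gradient computation would.
    INDArray points = Nd4j.rand(500, 2);
    QuadTree tree = new QuadTree(points);

    INDArray negForce = Nd4j.zeros(2);
    AtomicDouble sumQ = new AtomicDouble(0.0);
    tree.computeNonEdgeForces(0, 0.5, negForce, sumQ);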
@ -1,104 +0,0 @@
/* Apache License, Version 2.0. SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.randomprojection;

import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;

import java.util.ArrayList;
import java.util.List;

/**
 * A forest of random-projection trees for approximate nearest-neighbor queries.
 */
@Data
public class RPForest {

    private int numTrees;
    private List<RPTree> trees;
    private INDArray data;
    private int maxSize = 1000;
    private String similarityFunction;

    /**
     * Create the rp forest with the specified number of trees
     * @param numTrees the number of trees in the forest
     * @param maxSize the max size of each tree
     * @param similarityFunction the distance function to use
     */
    public RPForest(int numTrees, int maxSize, String similarityFunction) {
        this.numTrees = numTrees;
        this.maxSize = maxSize;
        this.similarityFunction = similarityFunction;
        trees = new ArrayList<>(numTrees);
    }

    /**
     * Build the trees from the given dataset
     * @param x the input dataset (should be a 2d matrix)
     */
    public void fit(INDArray x) {
        this.data = x;
        for (int i = 0; i < numTrees; i++) {
            RPTree tree = new RPTree(data.columns(), maxSize, similarityFunction);
            tree.buildTree(x);
            trees.add(tree);
        }
    }

    /**
     * Get all candidates relative to a specific datapoint.
     * @param input the query point
     * @return the indices of all candidate neighbors across the forest
     */
    public INDArray getAllCandidates(INDArray input) {
        return RPUtils.getAllCandidates(input, trees, similarityFunction);
    }

    /**
     * Query results up to length n
     * nearest neighbors
     * @param toQuery the query item
     * @param n the number of nearest neighbors for the given data point
     * @return the indices for the nearest neighbors
     */
    public INDArray queryAll(INDArray toQuery, int n) {
        return RPUtils.queryAll(toQuery, data, trees, n, similarityFunction);
    }

    /**
     * Query all with the distances
     * sorted by index
     * @param query the query vector
     * @param numResults the number of results to return
     * @return a list of (distance, index) pairs
     */
    public List<Pair<Double, Integer>> queryWithDistances(INDArray query, int numResults) {
        return RPUtils.queryAllWithDistances(query, this.data, trees, numResults, similarityFunction);
    }
}
|
|
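For reference, RPForest was the public entry point of the removed random-projection module. A minimal usage sketch, not part of this commit; the shapes and parameter values are illustrative assumptions:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RPForestExample {
    public static void main(String[] args) {
        INDArray data = Nd4j.randn(1000, 32);                    // 1000 points of dimension 32
        RPForest forest = new RPForest(10, 100, "euclidean");    // 10 trees, leaf size 100
        forest.fit(data);                                        // builds one RPTree per tree
        INDArray neighbors = forest.queryAll(data.getRow(0), 5); // indices of the 5 nearest rows
        System.out.println(neighbors);
    }
}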
@@ -1,57 +0,0 @@
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.randomprojection;

import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

@Data
public class RPHyperPlanes {
    private int dim;
    private INDArray wholeHyperPlane;

    public RPHyperPlanes(int dim) {
        this.dim = dim;
    }

    public INDArray getHyperPlaneAt(int depth) {
        if (wholeHyperPlane.isVector())
            return wholeHyperPlane;
        return wholeHyperPlane.slice(depth);
    }

    /**
     * Add a new random element to the hyper plane.
     */
    public void addRandomHyperPlane() {
        INDArray newPlane = Nd4j.randn(new int[] {1, dim});
        newPlane.divi(newPlane.normmaxNumber());
        if (wholeHyperPlane == null)
            wholeHyperPlane = newPlane;
        else {
            wholeHyperPlane = Nd4j.concat(0, wholeHyperPlane, newPlane);
        }
    }

}
@@ -1,48 +0,0 @@
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.randomprojection;

import lombok.Data;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Future;

@Data
public class RPNode {
    private int depth;
    private RPNode left, right;
    private Future<RPNode> leftFuture, rightFuture;
    private List<Integer> indices;
    private double median;
    private RPTree tree;

    public RPNode(RPTree tree, int depth) {
        this.depth = depth;
        this.tree = tree;
        indices = new ArrayList<>();
    }

}
@@ -1,130 +0,0 @@
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.randomprojection;

import lombok.Builder;
import lombok.Data;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.*;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.common.primitives.Pair;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ExecutorService;

@Data
public class RPTree {
    private RPNode root;
    private RPHyperPlanes rpHyperPlanes;
    private int dim;
    //also known as leaf size
    private int maxSize;
    private INDArray X;
    private String similarityFunction = "euclidean";
    private WorkspaceConfiguration workspaceConfiguration;
    private ExecutorService searchExecutor;
    private int searchWorkers;

    /**
     * @param dim the dimension of the vectors
     * @param maxSize the max size of the leaves
     * @param similarityFunction the distance function to use
     */
    @Builder
    public RPTree(int dim, int maxSize, String similarityFunction) {
        this.dim = dim;
        this.maxSize = maxSize;
        rpHyperPlanes = new RPHyperPlanes(dim);
        root = new RPNode(this, 0);
        this.similarityFunction = similarityFunction;
        workspaceConfiguration = WorkspaceConfiguration.builder().cyclesBeforeInitialization(1)
                        .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.FIRST_LOOP)
                        .policyMirroring(MirroringPolicy.FULL).policyReset(ResetPolicy.BLOCK_LEFT)
                        .policySpill(SpillPolicy.REALLOCATE).build();
    }

    /**
     * @param dim the dimension of the vectors
     * @param maxSize the max size of the leaves
     */
    public RPTree(int dim, int maxSize) {
        this(dim, maxSize, "euclidean");
    }

    /**
     * Build the tree with the given input data
     * @param x the input dataset (a 2d matrix, one row per point)
     */
    public void buildTree(INDArray x) {
        this.X = x;
        for (int i = 0; i < x.rows(); i++) {
            root.getIndices().add(i);
        }

        RPUtils.buildTree(this, root, rpHyperPlanes, x, maxSize, 0, similarityFunction);
    }

    public void addNodeAtIndex(int idx, INDArray toAdd) {
        RPNode query = RPUtils.query(root, rpHyperPlanes, toAdd, similarityFunction);
        query.getIndices().add(idx);
    }

    public List<RPNode> getLeaves() {
        List<RPNode> nodes = new ArrayList<>();
        RPUtils.scanForLeaves(nodes, getRoot());
        return nodes;
    }

    /**
     * Query all with the distances
     * sorted by index
     * @param query the query vector
     * @param numResults the number of results to return
     * @return a list of samples
     */
    public List<Pair<Double, Integer>> queryWithDistances(INDArray query, int numResults) {
        return RPUtils.queryAllWithDistances(query, X, Arrays.asList(this), numResults, similarityFunction);
    }

    public INDArray query(INDArray query, int numResults) {
        return RPUtils.queryAll(query, X, Arrays.asList(this), numResults, similarityFunction);
    }

    public List<Integer> getCandidates(INDArray target) {
        return RPUtils.getCandidates(target, Arrays.asList(this), similarityFunction);
    }

}
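A single RPTree can also be used standalone. A brief sketch under the same illustrative assumptions as the RPForest example above:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RPTreeExample {
    public static void main(String[] args) {
        INDArray data = Nd4j.randn(500, 16);
        RPTree tree = new RPTree(16, 50);            // dimension 16, leaf size 50, euclidean by default
        tree.buildTree(data);
        System.out.println(tree.getLeaves().size()); // leaf buckets produced by the median splits
        System.out.println(tree.query(data.getRow(0), 3)); // 3 approximate nearest neighbor indices
    }
}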
@@ -1,481 +0,0 @@
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.randomprojection;

import org.nd4j.shade.guava.primitives.Doubles;
import lombok.val;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.reduce3.*;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;

import java.util.*;

public class RPUtils {

    private static ThreadLocal<Map<String, DifferentialFunction>> functionInstances = new ThreadLocal<>();

    public static <T extends DifferentialFunction> DifferentialFunction getOp(String name,
                                                                              INDArray x,
                                                                              INDArray y,
                                                                              INDArray result) {
        Map<String, DifferentialFunction> ops = functionInstances.get();
        if (ops == null) {
            ops = new HashMap<>();
            functionInstances.set(ops);
        }

        boolean allDistances = x.length() != y.length();

        switch (name) {
            case "cosinedistance":
                if (!ops.containsKey(name) || ((CosineDistance) ops.get(name)).isComplexAccumulation() != allDistances) {
                    CosineDistance cosineDistance = new CosineDistance(x, y, result, allDistances);
                    ops.put(name, cosineDistance);
                    return cosineDistance;
                } else {
                    CosineDistance cosineDistance = (CosineDistance) ops.get(name);
                    return cosineDistance;
                }
            case "cosinesimilarity":
                if (!ops.containsKey(name) || ((CosineSimilarity) ops.get(name)).isComplexAccumulation() != allDistances) {
                    CosineSimilarity cosineSimilarity = new CosineSimilarity(x, y, result, allDistances);
                    ops.put(name, cosineSimilarity);
                    return cosineSimilarity;
                } else {
                    CosineSimilarity cosineSimilarity = (CosineSimilarity) ops.get(name);
                    cosineSimilarity.setX(x);
                    cosineSimilarity.setY(y);
                    cosineSimilarity.setZ(result);
                    return cosineSimilarity;
                }
            case "manhattan":
                if (!ops.containsKey(name) || ((ManhattanDistance) ops.get(name)).isComplexAccumulation() != allDistances) {
                    ManhattanDistance manhattanDistance = new ManhattanDistance(x, y, result, allDistances);
                    ops.put(name, manhattanDistance);
                    return manhattanDistance;
                } else {
                    ManhattanDistance manhattanDistance = (ManhattanDistance) ops.get(name);
                    manhattanDistance.setX(x);
                    manhattanDistance.setY(y);
                    manhattanDistance.setZ(result);
                    return manhattanDistance;
                }
            case "jaccard":
                if (!ops.containsKey(name) || ((JaccardDistance) ops.get(name)).isComplexAccumulation() != allDistances) {
                    JaccardDistance jaccardDistance = new JaccardDistance(x, y, result, allDistances);
                    ops.put(name, jaccardDistance);
                    return jaccardDistance;
                } else {
                    JaccardDistance jaccardDistance = (JaccardDistance) ops.get(name);
                    jaccardDistance.setX(x);
                    jaccardDistance.setY(y);
                    jaccardDistance.setZ(result);
                    return jaccardDistance;
                }
            case "hamming":
                if (!ops.containsKey(name) || ((HammingDistance) ops.get(name)).isComplexAccumulation() != allDistances) {
                    HammingDistance hammingDistance = new HammingDistance(x, y, result, allDistances);
                    ops.put(name, hammingDistance);
                    return hammingDistance;
                } else {
                    HammingDistance hammingDistance = (HammingDistance) ops.get(name);
                    hammingDistance.setX(x);
                    hammingDistance.setY(y);
                    hammingDistance.setZ(result);
                    return hammingDistance;
                }
            //euclidean
            default:
                if (!ops.containsKey(name) || ((EuclideanDistance) ops.get(name)).isComplexAccumulation() != allDistances) {
                    EuclideanDistance euclideanDistance = new EuclideanDistance(x, y, result, allDistances);
                    ops.put(name, euclideanDistance);
                    return euclideanDistance;
                } else {
                    EuclideanDistance euclideanDistance = (EuclideanDistance) ops.get(name);
                    euclideanDistance.setX(x);
                    euclideanDistance.setY(y);
                    euclideanDistance.setZ(result);
                    return euclideanDistance;
                }
        }
    }

    /**
     * Query all trees using the given input and data
     * @param toQuery the query vector
     * @param X the input data to query
     * @param trees the trees to query
     * @param n the number of results to search for
     * @param similarityFunction the similarity function to use
     * @return the distances and indices, sorted by ascending distance
     */
    public static List<Pair<Double, Integer>> queryAllWithDistances(INDArray toQuery, INDArray X, List<RPTree> trees, int n, String similarityFunction) {
        if (trees.isEmpty()) {
            throw new ND4JIllegalArgumentException("Trees is empty!");
        }

        List<Integer> candidates = getCandidates(toQuery, trees, similarityFunction);
        val sortedCandidates = sortCandidates(toQuery, X, candidates, similarityFunction);
        int numReturns = Math.min(n, sortedCandidates.size());
        List<Pair<Double, Integer>> ret = new ArrayList<>(numReturns);
        for (int i = 0; i < numReturns; i++) {
            ret.add(sortedCandidates.get(i));
        }

        return ret;
    }

    /**
     * Query all trees using the given input and data
     * @param toQuery the query vector
     * @param X the input data to query
     * @param trees the trees to query
     * @param n the number of results to search for
     * @param similarityFunction the similarity function to use
     * @return the indices (in order) in the ndarray
     */
    public static INDArray queryAll(INDArray toQuery, INDArray X, List<RPTree> trees, int n, String similarityFunction) {
        if (trees.isEmpty()) {
            throw new ND4JIllegalArgumentException("Trees is empty!");
        }

        List<Integer> candidates = getCandidates(toQuery, trees, similarityFunction);
        val sortedCandidates = sortCandidates(toQuery, X, candidates, similarityFunction);
        int numReturns = Math.min(n, sortedCandidates.size());

        INDArray result = Nd4j.create(numReturns);
        for (int i = 0; i < numReturns; i++) {
            result.putScalar(i, sortedCandidates.get(i).getSecond());
        }

        return result;
    }

    /**
     * Get the sorted distances given the
     * query vector, input data, given the list of possible search candidates
     * @param x the query vector
     * @param X the input data to use
     * @param candidates the possible search candidates
     * @param similarityFunction the similarity function to use
     * @return the sorted distances
     */
    public static List<Pair<Double, Integer>> sortCandidates(INDArray x, INDArray X,
                                                             List<Integer> candidates,
                                                             String similarityFunction) {
        int prevIdx = -1;
        List<Pair<Double, Integer>> ret = new ArrayList<>();
        for (int i = 0; i < candidates.size(); i++) {
            if (candidates.get(i) != prevIdx) {
                ret.add(Pair.of(computeDistance(similarityFunction, X.slice(candidates.get(i)), x), candidates.get(i)));
            }

            //remember the last candidate value so consecutive duplicates are skipped
            prevIdx = candidates.get(i);
        }

        Collections.sort(ret, new Comparator<Pair<Double, Integer>>() {
            @Override
            public int compare(Pair<Double, Integer> doubleIntegerPair, Pair<Double, Integer> t1) {
                return Doubles.compare(doubleIntegerPair.getFirst(), t1.getFirst());
            }
        });

        return ret;
    }

    /**
     * Get the search candidates as indices given the input
     * and similarity function
     * @param x the input data to search with
     * @param trees the trees to search
     * @param similarityFunction the function to use for similarity
     * @return the list of indices as the search results
     */
    public static INDArray getAllCandidates(INDArray x, List<RPTree> trees, String similarityFunction) {
        List<Integer> candidates = getCandidates(x, trees, similarityFunction);
        Collections.sort(candidates);

        int prevIdx = -1;
        int idxCount = 0;
        List<Pair<Integer, Integer>> scores = new ArrayList<>();
        for (int i = 0; i < candidates.size(); i++) {
            if (candidates.get(i) == prevIdx) {
                idxCount++;
            } else if (prevIdx != -1) {
                scores.add(Pair.of(idxCount, prevIdx));
                idxCount = 1;
            }

            //remember the last candidate value so repeat counts accumulate correctly
            prevIdx = candidates.get(i);
        }

        scores.add(Pair.of(idxCount, prevIdx));

        INDArray arr = Nd4j.create(scores.size());
        for (int i = 0; i < scores.size(); i++) {
            arr.putScalar(i, scores.get(i).getSecond());
        }

        return arr;
    }

    /**
     * Get the search candidates as indices given the input
     * and similarity function
     * @param x the input data to search with
     * @param roots the trees to search
     * @param similarityFunction the function to use for similarity
     * @return the list of indices as the search results
     */
    public static List<Integer> getCandidates(INDArray x, List<RPTree> roots, String similarityFunction) {
        Set<Integer> ret = new LinkedHashSet<>();
        for (RPTree tree : roots) {
            RPNode root = tree.getRoot();
            RPNode query = query(root, tree.getRpHyperPlanes(), x, similarityFunction);
            ret.addAll(query.getIndices());
        }

        return new ArrayList<>(ret);
    }

    /**
     * Query the tree starting from the given node
     * using the given hyper plane and similarity function
     * @param from the node to start from
     * @param planes the hyper plane to query
     * @param x the input data
     * @param similarityFunction the similarity function to use
     * @return the leaf node representing the given query from a
     * search in the tree
     */
    public static RPNode query(RPNode from, RPHyperPlanes planes, INDArray x, String similarityFunction) {
        if (from.getLeft() == null && from.getRight() == null) {
            return from;
        }

        INDArray hyperPlane = planes.getHyperPlaneAt(from.getDepth());
        double dist = computeDistance(similarityFunction, x, hyperPlane);
        if (dist <= from.getMedian()) {
            return query(from.getLeft(), planes, x, similarityFunction);
        } else {
            return query(from.getRight(), planes, x, similarityFunction);
        }
    }

    /**
     * Compute the distances between each slice of x and y
     * given a function name. Valid function names:
     * euclidean: euclidean distance
     * cosinedistance: cosine distance
     * cosinesimilarity: cosine similarity
     * manhattan: manhattan distance
     * jaccard: jaccard distance
     * hamming: hamming distance
     * @param function the function to use (default euclidean distance)
     * @param x the first vector
     * @param y the second vector
     * @param result the array to store the result in
     * @return the distances between the vectors given the inputs
     */
    public static INDArray computeDistanceMulti(String function, INDArray x, INDArray y, INDArray result) {
        ReduceOp op = (ReduceOp) getOp(function, x, y, result);
        op.setDimensions(1);
        Nd4j.getExecutioner().exec(op);
        return op.z();
    }

    /**
     * Compute the distance between 2 vectors
     * given a function name. Valid function names:
     * euclidean: euclidean distance
     * cosinedistance: cosine distance
     * cosinesimilarity: cosine similarity
     * manhattan: manhattan distance
     * jaccard: jaccard distance
     * hamming: hamming distance
     * @param function the function to use (default euclidean distance)
     * @param x the first vector
     * @param y the second vector
     * @param result the array to store the result in
     * @return the distance between the 2 vectors given the inputs
     */
    public static double computeDistance(String function, INDArray x, INDArray y, INDArray result) {
        ReduceOp op = (ReduceOp) getOp(function, x, y, result);
        Nd4j.getExecutioner().exec(op);
        return op.z().getDouble(0);
    }

    /**
     * Compute the distance between 2 vectors
     * given a function name. Valid function names:
     * euclidean: euclidean distance
     * cosinedistance: cosine distance
     * cosinesimilarity: cosine similarity
     * manhattan: manhattan distance
     * jaccard: jaccard distance
     * hamming: hamming distance
     * @param function the function to use (default euclidean distance)
     * @param x the first vector
     * @param y the second vector
     * @return the distance between the 2 vectors given the inputs
     */
    public static double computeDistance(String function, INDArray x, INDArray y) {
        return computeDistance(function, x, y, Nd4j.scalar(0.0));
    }

    /**
     * Initialize the tree given the input parameters
     * @param tree the tree to initialize
     * @param from the starting node
     * @param planes the hyper planes to use (vector space for similarity)
     * @param X the input data
     * @param maxSize the max number of indices on a given leaf node
     * @param depth the current depth of the tree
     * @param similarityFunction the similarity function to use
     */
    public static void buildTree(RPTree tree,
                                 RPNode from,
                                 RPHyperPlanes planes,
                                 INDArray X,
                                 int maxSize,
                                 int depth,
                                 String similarityFunction) {
        if (from.getIndices().size() <= maxSize) {
            //slimNode
            slimNode(from);
            return;
        }

        List<Double> distances = new ArrayList<>();
        RPNode left = new RPNode(tree, depth + 1);
        RPNode right = new RPNode(tree, depth + 1);

        if (planes.getWholeHyperPlane() == null || depth >= planes.getWholeHyperPlane().rows()) {
            planes.addRandomHyperPlane();
        }

        INDArray hyperPlane = planes.getHyperPlaneAt(depth);

        for (int i = 0; i < from.getIndices().size(); i++) {
            double cosineSim = computeDistance(similarityFunction, hyperPlane, X.slice(from.getIndices().get(i)));
            distances.add(cosineSim);
        }

        Collections.sort(distances);
        from.setMedian(distances.get(distances.size() / 2));

        for (int i = 0; i < from.getIndices().size(); i++) {
            double cosineSim = computeDistance(similarityFunction, hyperPlane, X.slice(from.getIndices().get(i)));
            if (cosineSim <= from.getMedian()) {
                left.getIndices().add(from.getIndices().get(i));
            } else {
                right.getIndices().add(from.getIndices().get(i));
            }
        }

        //failed split
        if (left.getIndices().isEmpty() || right.getIndices().isEmpty()) {
            slimNode(from);
            return;
        }

        from.setLeft(left);
        from.setRight(right);
        slimNode(from);

        buildTree(tree, left, planes, X, maxSize, depth + 1, similarityFunction);
        buildTree(tree, right, planes, X, maxSize, depth + 1, similarityFunction);
    }

    /**
     * Scan for leaves accumulating
     * the nodes in the passed in list
     * @param nodes the nodes so far
     * @param scan the tree to scan
     */
    public static void scanForLeaves(List<RPNode> nodes, RPTree scan) {
        scanForLeaves(nodes, scan.getRoot());
    }

    /**
     * Scan for leaves accumulating
     * the nodes in the passed in list
     * @param nodes the nodes so far
     * @param current the node to scan from
     */
    public static void scanForLeaves(List<RPNode> nodes, RPNode current) {
        if (current.getLeft() == null && current.getRight() == null)
            nodes.add(current);
        if (current.getLeft() != null)
            scanForLeaves(nodes, current.getLeft());
        if (current.getRight() != null)
            scanForLeaves(nodes, current.getRight());
    }

    /**
     * Prune indices from the given node
     * when it is not a leaf
     * @param node the node to prune
     */
    public static void slimNode(RPNode node) {
        if (node.getRight() != null && node.getLeft() != null) {
            node.getIndices().clear();
        }
    }

}
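The distance helpers above can be exercised directly. A small sanity-check sketch, not part of this commit; the printed values are what the named functions should produce given the nd4j reduce3 ops:

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class RPUtilsExample {
    public static void main(String[] args) {
        INDArray a = Nd4j.create(new double[] {1, 0, 0});
        INDArray b = Nd4j.create(new double[] {0, 1, 0});
        // Unrecognized names fall through to the euclidean branch of getOp.
        System.out.println(RPUtils.computeDistance("euclidean", a, b));      // sqrt(2) ~ 1.414
        System.out.println(RPUtils.computeDistance("cosinedistance", a, b)); // 1.0 for orthogonal unit vectors
    }
}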
@@ -1,87 +0,0 @@
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.sptree;

import org.nd4j.linalg.api.ndarray.INDArray;

import java.io.Serializable;

/**
 * @author Adam Gibson
 */
public class Cell implements Serializable {
    private int dimension;
    private INDArray corner, width;

    public Cell(int dimension) {
        this.dimension = dimension;
    }

    public double corner(int d) {
        return corner.getDouble(d);
    }

    public double width(int d) {
        return width.getDouble(d);
    }

    public void setCorner(int d, double corner) {
        this.corner.putScalar(d, corner);
    }

    public void setWidth(int d, double width) {
        this.width.putScalar(d, width);
    }

    public void setWidth(INDArray width) {
        this.width = width;
    }

    public void setCorner(INDArray corner) {
        this.corner = corner;
    }

    public boolean contains(INDArray point) {
        INDArray cornerMinusWidth = corner.sub(width);
        INDArray cornerPlusWidth = corner.add(width);
        for (int d = 0; d < dimension; d++) {
            double pointD = point.getDouble(d);
            if (cornerMinusWidth.getDouble(d) > pointD)
                return false;
            if (cornerPlusWidth.getDouble(d) < pointD)
                return false;
        }
        return true;
    }

    public INDArray width() {
        return width;
    }

    public INDArray corner() {
        return corner;
    }

}
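Note that contains effectively treats corner as a center point and width as a per-dimension half-extent, accepting points within corner ± width. A quick sketch of that behavior, not part of this commit:

import org.nd4j.linalg.factory.Nd4j;

public class CellExample {
    public static void main(String[] args) {
        Cell cell = new Cell(2);                          // a 2-dimensional box
        cell.setCorner(Nd4j.create(new double[] {0, 0})); // acts as the box center
        cell.setWidth(Nd4j.create(new double[] {1, 1}));  // half-extent per dimension
        System.out.println(cell.contains(Nd4j.create(new double[] {0.5, -0.5}))); // true
        System.out.println(cell.contains(Nd4j.create(new double[] {1.5, 0.0})));  // false
    }
}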
@@ -1,95 +0,0 @@
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.sptree;

import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity;
import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance;
import org.nd4j.linalg.factory.Nd4j;

import java.io.Serializable;

@Data
public class DataPoint implements Serializable {
    private int index;
    private INDArray point;
    private long d;
    private String functionName;
    private boolean invert = false;

    public DataPoint(int index, INDArray point, boolean invert) {
        this(index, point, "euclidean");
        this.invert = invert;
    }

    public DataPoint(int index, INDArray point, String functionName, boolean invert) {
        this.index = index;
        this.point = point;
        this.functionName = functionName;
        this.d = point.length();
        this.invert = invert;
    }

    public DataPoint(int index, INDArray point) {
        this(index, point, false);
    }

    public DataPoint(int index, INDArray point, String functionName) {
        this(index, point, functionName, false);
    }

    /**
     * Distance from this point to the given point,
     * using the configured distance function
     * @param point the point to measure the distance to
     * @return the distance between the two points
     */
    public float distance(DataPoint point) {
        switch (functionName) {
            case "euclidean":
                float ret = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this.point, point.point))
                                .getFinalResult().floatValue();
                return invert ? -ret : ret;

            case "cosinesimilarity":
                float ret2 = Nd4j.getExecutioner().execAndReturn(new CosineSimilarity(this.point, point.point))
                                .getFinalResult().floatValue();
                return invert ? -ret2 : ret2;

            case "manhattan":
                float ret3 = Nd4j.getExecutioner().execAndReturn(new ManhattanDistance(this.point, point.point))
                                .getFinalResult().floatValue();
                return invert ? -ret3 : ret3;
            case "dot":
                float dotRet = (float) Nd4j.getBlasWrapper().dot(this.point, point.point);
                return invert ? -dotRet : dotRet;
            default:
                float ret4 = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this.point, point.point))
                                .getFinalResult().floatValue();
                return invert ? -ret4 : ret4;
        }
    }

}
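A short sketch of DataPoint distances, not part of this commit: a 3-4-5 check under the default euclidean function.

import org.nd4j.linalg.factory.Nd4j;

public class DataPointExample {
    public static void main(String[] args) {
        DataPoint p1 = new DataPoint(0, Nd4j.create(new double[] {0, 0}));
        DataPoint p2 = new DataPoint(1, Nd4j.create(new double[] {3, 4}));
        System.out.println(p1.distance(p2)); // 5.0 by the 3-4-5 triangle
    }
}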
@@ -1,83 +0,0 @@
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.sptree;

import java.io.Serializable;

/**
 * @author Adam Gibson
 */
public class HeapItem implements Serializable, Comparable<HeapItem> {
    private int index;
    private double distance;

    public HeapItem(int index, double distance) {
        this.index = index;
        this.distance = distance;
    }

    public int getIndex() {
        return index;
    }

    public void setIndex(int index) {
        this.index = index;
    }

    public double getDistance() {
        return distance;
    }

    public void setDistance(double distance) {
        this.distance = distance;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o)
            return true;
        if (o == null || getClass() != o.getClass())
            return false;

        HeapItem heapItem = (HeapItem) o;

        if (index != heapItem.index)
            return false;
        return Double.compare(heapItem.distance, distance) == 0;
    }

    @Override
    public int hashCode() {
        int result;
        long temp;
        result = index;
        temp = Double.doubleToLongBits(distance);
        result = 31 * result + (int) (temp ^ (temp >>> 32));
        return result;
    }

    @Override
    public int compareTo(HeapItem o) {
        //order by descending distance; Double.compare keeps the Comparable contract
        return Double.compare(o.distance, distance);
    }
}
@@ -1,72 +0,0 @@
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.sptree;

import lombok.Data;
import org.nd4j.linalg.api.ndarray.INDArray;

import java.io.Serializable;

@Data
public class HeapObject implements Serializable, Comparable<HeapObject> {
    private int index;
    private INDArray point;
    private double distance;

    public HeapObject(int index, INDArray point, double distance) {
        this.index = index;
        this.point = point;
        this.distance = distance;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o)
            return true;
        if (o == null || getClass() != o.getClass())
            return false;

        HeapObject heapObject = (HeapObject) o;

        if (!point.equals(heapObject.point))
            return false;

        return Double.compare(heapObject.distance, distance) == 0;
    }

    @Override
    public int hashCode() {
        int result;
        long temp;
        result = index;
        temp = Double.doubleToLongBits(distance);
        result = 31 * result + (int) (temp ^ (temp >>> 32));
        return result;
    }

    @Override
    public int compareTo(HeapObject o) {
        //order by descending distance; Double.compare keeps the Comparable contract
        return Double.compare(o.distance, distance);
    }
}
@@ -1,425 +0,0 @@
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.sptree;

import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import lombok.val;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Set;

/**
 * @author Adam Gibson
 */
public class SpTree implements Serializable {

    public final static String workspaceExternal = "SPTREE_LOOP_EXTERNAL";

    private int D;
    private INDArray data;
    public final static int NODE_RATIO = 8000;
    private int N;
    private int size;
    private int cumSize;
    private Cell boundary;
    private INDArray centerOfMass;
    private SpTree parent;
    private int[] index;
    private int nodeCapacity;
    private int numChildren = 2;
    private boolean isLeaf = true;
    private Collection<INDArray> indices;
    private SpTree[] children;
    private static Logger log = LoggerFactory.getLogger(SpTree.class);
    private String similarityFunction = Distance.EUCLIDEAN.toString();

    public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices,
                    String similarityFunction) {
        init(parent, data, corner, width, indices, similarityFunction);
    }

    public SpTree(INDArray data, Collection<INDArray> indices, String similarityFunction) {
        this.indices = indices;
        this.N = data.rows();
        this.D = data.columns();
        this.similarityFunction = similarityFunction;
        data = data.dup();
        INDArray meanY = data.mean(0);
        INDArray minY = data.min(0);
        INDArray maxY = data.max(0);
        INDArray width = Nd4j.create(data.dataType(), meanY.shape());
        for (int i = 0; i < width.length(); i++) {
            width.putScalar(i, Math.max(maxY.getDouble(i) - meanY.getDouble(i),
                            meanY.getDouble(i) - minY.getDouble(i)) + Nd4j.EPS_THRESHOLD);
        }

        try (MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            init(null, data, meanY, width, indices, similarityFunction);
            fill(N);
        }
    }

    public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices) {
        this(parent, data, corner, width, indices, "euclidean");
    }

    public SpTree(INDArray data, Collection<INDArray> indices) {
        this(data, indices, "euclidean");
    }

    public SpTree(INDArray data) {
        this(data, new ArrayList<INDArray>());
    }

    public MemoryWorkspace workspace() {
        return null;
    }

    private void init(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices,
                    String similarityFunction) {

        this.parent = parent;
        D = data.columns();
        N = data.rows();
        this.similarityFunction = similarityFunction;
        nodeCapacity = N % NODE_RATIO;
        index = new int[nodeCapacity];
        for (int d = 1; d < this.D; d++)
            numChildren *= 2;
        this.indices = indices;
        isLeaf = true;
        size = 0;
        cumSize = 0;
        children = new SpTree[numChildren];
        this.data = data;
        boundary = new Cell(D);
        boundary.setCorner(corner.dup());
        boundary.setWidth(width.dup());
        centerOfMass = Nd4j.create(data.dataType(), D);
    }

    private boolean insert(int index) {
        /*MemoryWorkspace workspace =
                        workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace()
                                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                                                        workspaceConfigurationExternal,
                                                        workspaceExternal);
        try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ {

            INDArray point = data.slice(index);
            /*boolean contains = false;
            SpTreeCell op = new SpTreeCell(boundary.corner(), boundary.width(), point, N, contains);
            Nd4j.getExecutioner().exec(op);
            op.getOutputArgument(0).getScalar(0);
            if (!contains) return false;*/
            if (!boundary.contains(point))
                return false;

            cumSize++;
            double mult1 = (double) (cumSize - 1) / (double) cumSize;
            double mult2 = 1.0 / (double) cumSize;
            centerOfMass.muli(mult1);
            centerOfMass.addi(point.mul(mult2));
            // If there is space in this quad tree and it is a leaf, add the object here
            if (isLeaf() && size < nodeCapacity) {
                this.index[size] = index;
                indices.add(point);
                size++;
                return true;
            }

            for (int i = 0; i < size; i++) {
                INDArray compPoint = data.slice(this.index[i]);
                if (compPoint.equals(point))
                    return true;
            }

            if (isLeaf())
                subDivide();

            // Find out where the point can be inserted
            for (int i = 0; i < numChildren; i++) {
                if (children[i].insert(index))
                    return true;
            }

            throw new IllegalStateException("Shouldn't reach this state");
        }
    }

    /**
     * Subdivide the node into 2^d children (4 in the 2d case)
     */
    public void subDivide() {
        /*MemoryWorkspace workspace =
                        workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace()
                                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                                                        workspaceConfigurationExternal,
                                                        workspaceExternal);
        try (MemoryWorkspace ws = workspace.notifyScopeEntered()) */{

            INDArray newCorner = Nd4j.create(data.dataType(), D);
            INDArray newWidth = Nd4j.create(data.dataType(), D);
            for (int i = 0; i < numChildren; i++) {
                int div = 1;
                for (int d = 0; d < D; d++) {
                    newWidth.putScalar(d, .5 * boundary.width(d));
                    if ((i / div) % 2 == 1)
                        newCorner.putScalar(d, boundary.corner(d) - .5 * boundary.width(d));
                    else
                        newCorner.putScalar(d, boundary.corner(d) + .5 * boundary.width(d));
                    div *= 2;
                }

                children[i] = new SpTree(this, data, newCorner, newWidth, indices);
            }

            // Move existing points to correct children
            for (int i = 0; i < size; i++) {
                boolean success = false;
                for (int j = 0; j < this.numChildren; j++)
                    if (!success)
                        success = children[j].insert(index[i]);

                index[i] = -1;
            }

            // Empty parent node
            size = 0;
            isLeaf = false;
        }
    }

    /**
     * Compute non edge forces using barnes hut
     * @param pointIndex the index of the point to compute repulsive forces for
     * @param theta the Barnes-Hut approximation threshold
     * @param negativeForce the accumulator for the negative (repulsive) force
     * @param sumQ the running sum of Q values used for normalization
     */
    public void computeNonEdgeForces(int pointIndex, double theta, INDArray negativeForce, AtomicDouble sumQ) {
        // Make sure that we spend no time on empty nodes or self-interactions
        INDArray buf = Nd4j.create(data.dataType(), this.D);

        if (cumSize == 0 || (isLeaf() && size == 1 && index[0] == pointIndex))
            return;
        /* MemoryWorkspace workspace =
                        workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace()
                                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                                                        workspaceConfigurationExternal,
                                                        workspaceExternal);
        try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ {

            // Compute distance between point and center-of-mass
            data.slice(pointIndex).subi(centerOfMass, buf);

            double D = Nd4j.getBlasWrapper().dot(buf, buf);
            // Check whether we can use this node as a "summary"
            double maxWidth = boundary.width().maxNumber().doubleValue();
            if (isLeaf() || maxWidth / Math.sqrt(D) < theta) {

                // Compute and add t-SNE force between point and current node
                double Q = 1.0 / (1.0 + D);
                double mult = cumSize * Q;
                sumQ.addAndGet(mult);
                mult *= Q;
                negativeForce.addi(buf.mul(mult));
            } else {

                // Recursively apply Barnes-Hut to children
                for (int i = 0; i < numChildren; i++) {
                    children[i].computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
                }
            }
        }
    }

    /**
     * Compute edge forces using barnes hut
     * @param rowP a vector of row offsets into colP/valP
     * @param colP the column indices of the neighbor edges
     * @param valP the value of each edge
     * @param N the number of elements
     * @param posF the positive force
     */
    public void computeEdgeForces(INDArray rowP, INDArray colP, INDArray valP, int N, INDArray posF) {
        if (!rowP.isVector())
            throw new IllegalArgumentException("RowP must be a vector");

        // Loop over all edges in the graph
        // just execute native op
        Nd4j.exec(new BarnesEdgeForces(rowP, colP, valP, data, N, posF));

        /*
        INDArray buf = Nd4j.create(data.dataType(), this.D);
        double D;
        for (int n = 0; n < N; n++) {
            INDArray slice = data.slice(n);
            for (int i = rowP.getInt(n); i < rowP.getInt(n + 1); i++) {

                // Compute pairwise distance and Q-value
                slice.subi(data.slice(colP.getInt(i)), buf);

                D = 1.0 + Nd4j.getBlasWrapper().dot(buf, buf);
                D = valP.getDouble(i) / D;

                // Sum positive force
                posF.slice(n).addi(buf.muli(D));
            }
        }
        */
    }

    public boolean isLeaf() {
        return isLeaf;
    }

    /**
     * Verifies the structure of the tree (does bounds checking on each node)
     * @return true if the structure of the tree
     * is correct.
     */
    public boolean isCorrect() {
        /*MemoryWorkspace workspace =
                        workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace()
                                        : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(
                                                        workspaceConfigurationExternal,
                                                        workspaceExternal);
        try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ {

            for (int n = 0; n < size; n++) {
                INDArray point = data.slice(index[n]);
                if (!boundary.contains(point))
                    return false;
            }
            if (!isLeaf()) {
                boolean correct = true;
                for (int i = 0; i < numChildren; i++)
                    correct = correct && children[i].isCorrect();
                return correct;
            }

            return true;
        }
    }

    /**
     * The depth of the node
     * @return the depth of the node
     */
    public int depth() {
        if (isLeaf())
            return 1;
        int depth = 1;
        int maxChildDepth = 0;
        for (int i = 0; i < numChildren; i++) {
            //take the max over every child, not just the first
            maxChildDepth = Math.max(maxChildDepth, children[i].depth());
        }

        return depth + maxChildDepth;
    }

    private void fill(int n) {
        if (indices.isEmpty() && parent == null)
            for (int i = 0; i < n; i++) {
                log.trace("Inserted " + i);
                insert(i);
            }
        else
            log.warn("Called fill already");
    }

    public SpTree[] getChildren() {
        return children;
    }

    public int getD() {
        return D;
    }

    public INDArray getCenterOfMass() {
        return centerOfMass;
    }

    public Cell getBoundary() {
        return boundary;
    }

    public int[] getIndex() {
        return index;
    }

    public int getCumSize() {
        return cumSize;
    }

    public void setCumSize(int cumSize) {
        this.cumSize = cumSize;
    }

    public int getNumChildren() {
        return numChildren;
    }

    public void setNumChildren(int numChildren) {
        this.numChildren = numChildren;
    }

}
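SpTree was the Barnes-Hut space-partitioning tree behind t-SNE: computeNonEdgeForces walks the tree and treats sufficiently distant cells as single summary points. A hedged sketch of the repulsive-force loop, not part of this commit; shapes and theta are illustrative:

import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class SpTreeExample {
    public static void main(String[] args) {
        INDArray embedding = Nd4j.randn(200, 2);   // low-dimensional t-SNE coordinates
        SpTree tree = new SpTree(embedding);       // builds the tree and inserts every row
        INDArray negForces = Nd4j.zeros(200, 2);
        AtomicDouble sumQ = new AtomicDouble(0);
        double theta = 0.5;                        // accuracy/speed trade-off
        for (int n = 0; n < 200; n++) {
            tree.computeNonEdgeForces(n, theta, negForces.slice(n), sumQ);
        }
        System.out.println(sumQ.get());            // normalization term for the gradient
    }
}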
@@ -1,117 +0,0 @@
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.strategy;

import lombok.*;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition;
import org.deeplearning4j.clustering.condition.ConvergenceCondition;
import org.deeplearning4j.clustering.condition.FixedIterationCountCondition;

import java.io.Serializable;

@AllArgsConstructor(access = AccessLevel.PROTECTED)
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public abstract class BaseClusteringStrategy implements ClusteringStrategy, Serializable {
    @Getter(AccessLevel.PUBLIC)
    @Setter(AccessLevel.PROTECTED)
    protected ClusteringStrategyType type;
    @Getter(AccessLevel.PUBLIC)
    @Setter(AccessLevel.PROTECTED)
    protected Integer initialClusterCount;
    @Getter(AccessLevel.PUBLIC)
    @Setter(AccessLevel.PROTECTED)
    protected ClusteringAlgorithmCondition optimizationPhaseCondition;
    @Getter(AccessLevel.PUBLIC)
    @Setter(AccessLevel.PROTECTED)
    protected ClusteringAlgorithmCondition terminationCondition;
    @Getter(AccessLevel.PUBLIC)
    @Setter(AccessLevel.PROTECTED)
    protected boolean inverse;
    @Getter(AccessLevel.PUBLIC)
    @Setter(AccessLevel.PROTECTED)
    protected Distance distanceFunction;
    @Getter(AccessLevel.PUBLIC)
    @Setter(AccessLevel.PROTECTED)
    protected boolean allowEmptyClusters;

    public BaseClusteringStrategy(ClusteringStrategyType type, Integer initialClusterCount, Distance distanceFunction,
                    boolean allowEmptyClusters, boolean inverse) {
        this.type = type;
        this.initialClusterCount = initialClusterCount;
        this.distanceFunction = distanceFunction;
        this.allowEmptyClusters = allowEmptyClusters;
        this.inverse = inverse;
    }

    public BaseClusteringStrategy(ClusteringStrategyType clusteringStrategyType, int initialClusterCount,
                    Distance distanceFunction, boolean inverse) {
        this(clusteringStrategyType, initialClusterCount, distanceFunction, false, inverse);
    }

    /**
     * Terminate once the given iteration count is reached
     * @param maxIterationCount the iteration count at which to stop
     * @return this strategy
     */
    public BaseClusteringStrategy endWhenIterationCountEquals(int maxIterationCount) {
        setTerminationCondition(FixedIterationCountCondition.iterationCountGreaterThan(maxIterationCount));
        return this;
    }

    /**
     * Terminate once the cluster distribution varies less than the given rate
     * @param rate the distribution variation rate threshold
     * @return this strategy
     */
    public BaseClusteringStrategy endWhenDistributionVariationRateLessThan(double rate) {
        setTerminationCondition(ConvergenceCondition.distributionVariationRateLessThan(rate));
        return this;
    }

    /**
     * @return true if the distance calculation is inverted
     */
    @Override
    public boolean inverseDistanceCalculation() {
        return inverse;
    }

    /**
     * @param type the type to check against
     * @return true if this strategy is of the given type
     */
    public boolean isStrategyOfType(ClusteringStrategyType type) {
        return type.equals(this.type);
    }

    /**
     * @return the initial cluster count
     */
    public Integer getInitialClusterCount() {
        return initialClusterCount;
    }

}
@ -1,102 +0,0 @@
|
|||
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.strategy;

import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition;
import org.deeplearning4j.clustering.iteration.IterationHistory;

/**
 * Configuration contract for a clustering run: the distance function, the
 * initial cluster count, and the conditions under which the algorithm
 * optimizes or terminates.
 */
public interface ClusteringStrategy {

    /**
     * @return true if the distance function is inverted (treated as a similarity to be maximized)
     */
    boolean inverseDistanceCalculation();

    /**
     * @return the type of this strategy (fixed cluster count or optimization)
     */
    ClusteringStrategyType getType();

    /**
     * @param type the strategy type to compare against
     * @return true if this strategy is of the given type
     */
    boolean isStrategyOfType(ClusteringStrategyType type);

    /**
     * @return the number of clusters the algorithm starts with
     */
    Integer getInitialClusterCount();

    /**
     * @return the distance function used to compare points
     */
    Distance getDistanceFunction();

    /**
     * @return true if clusters are allowed to become empty during iteration
     */
    boolean isAllowEmptyClusters();

    /**
     * @return the condition that, once satisfied, ends the clustering run
     */
    ClusteringAlgorithmCondition getTerminationCondition();

    /**
     * @return true if an optimization phase has been configured
     */
    boolean isOptimizationDefined();

    /**
     * @param iterationHistory the history of iterations so far
     * @return true if the optimization phase should be applied at this point
     */
    boolean isOptimizationApplicableNow(IterationHistory iterationHistory);

    /**
     * Terminate clustering once the iteration count exceeds the given maximum.
     *
     * @param maxIterationCount the maximum number of iterations to run
     * @return the strategy, for chaining
     */
    BaseClusteringStrategy endWhenIterationCountEquals(int maxIterationCount);

    /**
     * Terminate clustering once the rate of change of the point distribution
     * falls below the given rate.
     *
     * @param rate the convergence threshold on the distribution variation rate
     * @return the strategy, for chaining
     */
    BaseClusteringStrategy endWhenDistributionVariationRateLessThan(double rate);
}
@@ -1,25 +0,0 @@
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.strategy;

public enum ClusteringStrategyType {
    FIXED_CLUSTER_COUNT, OPTIMIZATION
}
@@ -1,68 +0,0 @@
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.strategy;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.iteration.IterationHistory;

/**
 * A clustering strategy that keeps a fixed number of clusters and defines
 * no optimization phase.
 */
@NoArgsConstructor(access = AccessLevel.PROTECTED)
public class FixedClusterCountStrategy extends BaseClusteringStrategy {

    protected FixedClusterCountStrategy(Integer initialClusterCount, Distance distanceFunction,
                    boolean allowEmptyClusters, boolean inverse) {
        super(ClusteringStrategyType.FIXED_CLUSTER_COUNT, initialClusterCount, distanceFunction, allowEmptyClusters,
                        inverse);
    }

    /**
     * Create a fixed-cluster-count strategy.
     *
     * @param clusterCount the number of clusters to produce
     * @param distanceFunction the distance function used to compare points
     * @param inverse whether to invert the distance (maximize a similarity instead)
     * @return the configured strategy
     */
    public static FixedClusterCountStrategy setup(int clusterCount, Distance distanceFunction, boolean inverse) {
        return new FixedClusterCountStrategy(clusterCount, distanceFunction, false, inverse);
    }

    /**
     * @return true if the distance function is inverted
     */
    @Override
    public boolean inverseDistanceCalculation() {
        return inverse;
    }

    public boolean isOptimizationDefined() {
        return false;
    }

    public boolean isOptimizationApplicableNow(IterationHistory iterationHistory) {
        return false;
    }
}
@@ -1,82 +0,0 @@
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.strategy;

import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition;
import org.deeplearning4j.clustering.condition.ConvergenceCondition;
import org.deeplearning4j.clustering.condition.FixedIterationCountCondition;
import org.deeplearning4j.clustering.iteration.IterationHistory;
import org.deeplearning4j.clustering.optimisation.ClusteringOptimization;
import org.deeplearning4j.clustering.optimisation.ClusteringOptimizationType;

public class OptimisationStrategy extends BaseClusteringStrategy {
    public static int defaultIterationCount = 100;

    private ClusteringOptimization clusteringOptimisation;
    private ClusteringAlgorithmCondition clusteringOptimisationApplicationCondition;

    protected OptimisationStrategy() {
        super();
    }

    protected OptimisationStrategy(int initialClusterCount, Distance distanceFunction) {
        super(ClusteringStrategyType.OPTIMIZATION, initialClusterCount, distanceFunction, false);
    }

    public static OptimisationStrategy setup(int initialClusterCount, Distance distanceFunction) {
        return new OptimisationStrategy(initialClusterCount, distanceFunction);
    }

    /** Optimize the given objective toward the given target value. */
    public OptimisationStrategy optimize(ClusteringOptimizationType type, double value) {
        clusteringOptimisation = new ClusteringOptimization(type, value);
        return this;
    }

    /**
     * Apply the optimization phase once the iteration count exceeds the given value.
     * Note: despite the name, the underlying condition fires when the count is
     * greater than {@code value}, not on every multiple of it.
     */
    public OptimisationStrategy optimizeWhenIterationCountMultipleOf(int value) {
        clusteringOptimisationApplicationCondition = FixedIterationCountCondition.iterationCountGreaterThan(value);
        return this;
    }

    /** Apply the optimization phase once the point-distribution variation rate falls below the given rate. */
    public OptimisationStrategy optimizeWhenPointDistributionVariationRateLessThan(double rate) {
        clusteringOptimisationApplicationCondition = ConvergenceCondition.distributionVariationRateLessThan(rate);
        return this;
    }

    public double getClusteringOptimizationValue() {
        return clusteringOptimisation.getValue();
    }

    public boolean isClusteringOptimizationType(ClusteringOptimizationType type) {
        return clusteringOptimisation != null && clusteringOptimisation.getType().equals(type);
    }

    public boolean isOptimizationDefined() {
        return clusteringOptimisation != null;
    }

    public boolean isOptimizationApplicableNow(IterationHistory iterationHistory) {
        return clusteringOptimisationApplicationCondition != null
                        && clusteringOptimisationApplicationCondition.isSatisfied(iterationHistory);
    }
}
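A hedged sketch of how an optimization-phase strategy might be configured. The ClusteringOptimizationType and Distance constants shown are assumptions, not verified against the enum definitions in this module.

import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.optimisation.ClusteringOptimizationType;
import org.deeplearning4j.clustering.strategy.ClusteringStrategy;
import org.deeplearning4j.clustering.strategy.OptimisationStrategy;

public class OptimisationSketch {
    public static void main(String[] args) {
        // Optimize the average point-to-center distance once 10 iterations have
        // passed, and stop after 100 iterations at the latest. Both enum
        // constant names below are assumptions.
        ClusteringStrategy strategy = OptimisationStrategy
                        .setup(5, Distance.EUCLIDEAN)
                        .optimize(ClusteringOptimizationType.MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE, 0.01)
                        .optimizeWhenIterationCountMultipleOf(10)
                        .endWhenIterationCountEquals(100);
    }
}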
File diff suppressed because it is too large

@@ -1,74 +0,0 @@
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.util;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.List;
import java.util.concurrent.*;

public class MultiThreadUtils {

    private static Logger log = LoggerFactory.getLogger(MultiThreadUtils.class);

    private static ExecutorService instance;

    private MultiThreadUtils() {}

    /** Create a daemon-threaded pool with one thread per available processor. */
    public static synchronized ExecutorService newExecutorService() {
        int nThreads = Runtime.getRuntime().availableProcessors();
        return new ThreadPoolExecutor(nThreads, nThreads, 60L, TimeUnit.SECONDS, new LinkedTransferQueue<Runnable>(),
                        new ThreadFactory() {
                            @Override
                            public Thread newThread(Runnable r) {
                                Thread t = Executors.defaultThreadFactory().newThread(r);
                                t.setDaemon(true);
                                return t;
                            }
                        });
    }

    /** Run all tasks on the given executor and block until every one has finished. */
    public static void parallelTasks(final List<Runnable> tasks, ExecutorService executorService) {
        int tasksCount = tasks.size();
        final CountDownLatch latch = new CountDownLatch(tasksCount);
        for (int i = 0; i < tasksCount; i++) {
            final int taskIdx = i;
            executorService.execute(new Runnable() {
                public void run() {
                    try {
                        tasks.get(taskIdx).run();
                    } catch (Throwable e) {
                        log.info("Unchecked exception thrown by task", e);
                    } finally {
                        // count down even on failure so the caller never blocks forever
                        latch.countDown();
                    }
                }
            });
        }

        try {
            latch.await();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }
}
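A minimal usage sketch for parallelTasks, assuming it runs in (or imports) org.deeplearning4j.clustering.util: the call blocks until every task has either completed or thrown.

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutorService;

public class ParallelTasksSketch {
    public static void main(String[] args) {
        ExecutorService pool = MultiThreadUtils.newExecutorService();

        List<Runnable> tasks = new ArrayList<>();
        for (int i = 0; i < 4; i++) {
            final int id = i;
            tasks.add(() -> System.out.println("task " + id + " on " + Thread.currentThread().getName()));
        }

        // Blocks until the internal latch has counted every task down.
        MultiThreadUtils.parallelTasks(tasks, pool);
        pool.shutdown();
    }
}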
@@ -1,61 +0,0 @@
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.util;

import java.util.Collection;
import java.util.HashSet;
import java.util.Set;

public class SetUtils {
    private SetUtils() {}

    // Set-specific operations

    /**
     * @return the elements common to both collections
     */
    public static <T> Set<T> intersection(Collection<T> parentCollection, Collection<T> removeFromCollection) {
        Set<T> results = new HashSet<>(parentCollection);
        results.retainAll(removeFromCollection);
        return results;
    }

    /**
     * @return true if the two sets have at least one element in common
     */
    public static <T> boolean intersectionP(Set<? extends T> s1, Set<? extends T> s2) {
        for (T elt : s1) {
            if (s2.contains(elt))
                return true;
        }
        return false;
    }

    public static <T> Set<T> union(Set<? extends T> s1, Set<? extends T> s2) {
        Set<T> s3 = new HashSet<>(s1);
        s3.addAll(s2);
        return s3;
    }

    /** Returns s1 \ s2, i.e. the elements of s1 that are not in s2. */
    public static <T> Set<T> difference(Collection<? extends T> s1, Collection<? extends T> s2) {
        Set<T> s3 = new HashSet<>(s1);
        s3.removeAll(s2);
        return s3;
    }
}
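A quick sketch of the set helpers above, using plain JDK collections:

import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;

public class SetUtilsSketch {
    public static void main(String[] args) {
        Set<Integer> a = new HashSet<>(Arrays.asList(1, 2, 3));
        Set<Integer> b = new HashSet<>(Arrays.asList(3, 4));

        System.out.println(SetUtils.intersection(a, b));  // [3]
        System.out.println(SetUtils.union(a, b));         // [1, 2, 3, 4]
        System.out.println(SetUtils.difference(a, b));    // [1, 2]
        System.out.println(SetUtils.intersectionP(a, b)); // true
    }
}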
@@ -1,633 +0,0 @@
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.vptree;

import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.deeplearning4j.clustering.sptree.HeapObject;
import org.deeplearning4j.clustering.util.MathUtils;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce3.*;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;

import java.io.Serializable;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Vantage-point tree for nearest-neighbour search under a configurable
 * distance or similarity function.
 */
@Slf4j
@Builder
@AllArgsConstructor
public class VPTree implements Serializable {
    private static final long serialVersionUID = 1L;

    public static final String EUCLIDEAN = "euclidean";
    private double tau;
    @Getter
    @Setter
    private INDArray items;
    private List<INDArray> itemsList;
    private Node root;
    private String similarityFunction;
    @Getter
    private boolean invert = false;
    private transient ExecutorService executorService;
    @Getter
    private int workers = 1;
    private AtomicInteger size = new AtomicInteger(0);

    private transient ThreadLocal<INDArray> scalars = new ThreadLocal<>();

    private WorkspaceConfiguration workspaceConfiguration;

    protected VPTree() {
        // no-arg constructor for serialization only
        scalars = new ThreadLocal<>();
    }

    /**
     * @param points the points to index, one per row
     * @param invert whether to invert the distance (similarity functions have different min/max objectives)
     */
    public VPTree(INDArray points, boolean invert) {
        this(points, EUCLIDEAN, 1, invert);
    }

    /**
     * @param points the points to index, one per row
     * @param invert whether to invert the distance
     * @param workers number of parallel workers for tree building (increases memory requirements!)
     */
    public VPTree(INDArray points, boolean invert, int workers) {
        this(points, EUCLIDEAN, workers, invert);
    }

    /**
     * @param items the items to use
     * @param similarityFunction the similarity function to use
     * @param invert whether to invert the distance (similarity functions have different min/max objectives)
     */
    public VPTree(INDArray items, String similarityFunction, boolean invert) {
        this.similarityFunction = similarityFunction;
        this.invert = invert;
        this.items = items;
        root = buildFromPoints(items);
        workers = 1;
    }

    /**
     * @param items the items to use
     * @param similarityFunction the similarity function to use
     * @param workers number of parallel workers for tree building (increases memory requirements!)
     * @param invert whether to invert the metric (different optimization objective)
     */
    public VPTree(List<DataPoint> items, String similarityFunction, int workers, boolean invert) {
        this.workers = workers;

        val list = new INDArray[items.size()];

        // build the array of row vectors first, then stack them in one concat
        for (int i = 0; i < items.size(); i++)
            list[i] = items.get(i).getPoint();

        this.items = Nd4j.pile(list);

        this.invert = invert;
        this.similarityFunction = similarityFunction;
        root = buildFromPoints(this.items);
    }

    /**
     * @param items the items to index
     * @param similarityFunction the similarity function to use
     */
    public VPTree(INDArray items, String similarityFunction) {
        this(items, similarityFunction, 1, false);
    }

    /**
     * @param items the items to index
     * @param similarityFunction the similarity function to use
     * @param workers number of parallel workers for tree building (increases memory requirements!)
     * @param invert whether to invert the distance
     */
    public VPTree(INDArray items, String similarityFunction, int workers, boolean invert) {
        this.similarityFunction = similarityFunction;
        this.invert = invert;
        this.items = items;

        this.workers = workers;
        root = buildFromPoints(items);
    }

    /**
     * @param items the items to index
     * @param similarityFunction the similarity function to use
     */
    public VPTree(List<DataPoint> items, String similarityFunction) {
        this(items, similarityFunction, 1, false);
    }

    /**
     * @param items the items to index (Euclidean distance)
     */
    public VPTree(INDArray items) {
        this(items, EUCLIDEAN);
    }

    /**
     * @param items the items to index (Euclidean distance)
     */
    public VPTree(List<DataPoint> items) {
        this(items, EUCLIDEAN);
    }

    /**
     * Create an ndarray from the data points, one point per slice.
     *
     * @param data the data points to stack
     * @return the stacked ndarray
     */
    public static INDArray buildFromData(List<DataPoint> data) {
        INDArray ret = Nd4j.create(data.size(), data.get(0).getD());
        for (int i = 0; i < ret.slices(); i++)
            ret.putSlice(i, data.get(i).getPoint());
        return ret;
    }

    /**
     * Compute the distance from every row of {@code items} to {@code basePoint},
     * writing the result into {@code distancesArr} (negated when the tree is inverted).
     *
     * @param items the rows to measure
     * @param basePoint the reference point
     * @param distancesArr the output column vector
     */
    public void calcDistancesRelativeTo(INDArray items, INDArray basePoint, INDArray distancesArr) {
        switch (similarityFunction) {
            case "euclidean":
                Nd4j.getExecutioner().exec(new EuclideanDistance(items, basePoint, distancesArr, true, -1));
                break;
            case "cosinedistance":
                Nd4j.getExecutioner().exec(new CosineDistance(items, basePoint, distancesArr, true, -1));
                break;
            case "cosinesimilarity":
                Nd4j.getExecutioner().exec(new CosineSimilarity(items, basePoint, distancesArr, true, -1));
                break;
            case "manhattan":
                Nd4j.getExecutioner().exec(new ManhattanDistance(items, basePoint, distancesArr, true, -1));
                break;
            case "dot":
                Nd4j.getExecutioner().exec(new Dot(items, basePoint, distancesArr, -1));
                break;
            case "jaccard":
                Nd4j.getExecutioner().exec(new JaccardDistance(items, basePoint, distancesArr, true, -1));
                break;
            case "hamming":
                Nd4j.getExecutioner().exec(new HammingDistance(items, basePoint, distancesArr, true, -1));
                break;
            default:
                Nd4j.getExecutioner().exec(new EuclideanDistance(items, basePoint, distancesArr, true, -1));
                break;
        }

        if (invert)
            distancesArr.negi();
    }

    public void calcDistancesRelativeTo(INDArray basePoint, INDArray distancesArr) {
        calcDistancesRelativeTo(items, basePoint, distancesArr);
    }

    /**
     * Distance between two points under the configured similarity function,
     * negated when the tree is inverted.
     *
     * @return the distance between the two points
     */
    public double distance(INDArray arr1, INDArray arr2) {
        if (scalars == null)
            scalars = new ThreadLocal<>();

        if (scalars.get() == null)
            scalars.set(Nd4j.scalar(arr1.dataType(), 0.0));

        switch (similarityFunction) {
            case "jaccard":
                double ret7 = Nd4j.getExecutioner().execAndReturn(new JaccardDistance(arr1, arr2, scalars.get()))
                                .getFinalResult().doubleValue();
                return invert ? -ret7 : ret7;
            case "hamming":
                double ret8 = Nd4j.getExecutioner().execAndReturn(new HammingDistance(arr1, arr2, scalars.get()))
                                .getFinalResult().doubleValue();
                return invert ? -ret8 : ret8;
            case "euclidean":
                double ret = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(arr1, arr2, scalars.get()))
                                .getFinalResult().doubleValue();
                return invert ? -ret : ret;
            case "cosinesimilarity":
                double ret2 = Nd4j.getExecutioner().execAndReturn(new CosineSimilarity(arr1, arr2, scalars.get()))
                                .getFinalResult().doubleValue();
                return invert ? -ret2 : ret2;
            case "cosinedistance":
                double ret6 = Nd4j.getExecutioner().execAndReturn(new CosineDistance(arr1, arr2, scalars.get()))
                                .getFinalResult().doubleValue();
                return invert ? -ret6 : ret6;
            case "manhattan":
                double ret3 = Nd4j.getExecutioner().execAndReturn(new ManhattanDistance(arr1, arr2, scalars.get()))
                                .getFinalResult().doubleValue();
                return invert ? -ret3 : ret3;
            case "dot":
                double dotRet = Nd4j.getBlasWrapper().dot(arr1, arr2);
                return invert ? -dotRet : dotRet;
            default:
                double ret4 = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(arr1, arr2, scalars.get()))
                                .getFinalResult().doubleValue();
                return invert ? -ret4 : ret4;
        }
    }

    protected class NodeBuilder implements Callable<Node> {
        protected List<INDArray> list;
        protected List<Integer> indices;

        public NodeBuilder(List<INDArray> list, List<Integer> indices) {
            this.list = list;
            this.indices = indices;
        }

        @Override
        public Node call() throws Exception {
            return buildFromPoints(list, indices);
        }
    }

    private Node buildFromPoints(List<INDArray> points, List<Integer> indices) {
        Node ret = new Node(0, 0);

        // single point: nothing to split
        if (points.size() == 1) {
            ret.point = points.get(0);
            ret.index = indices.get(0);
            return ret;
        }

        // (workspace-based allocation was experimented with here and is intentionally disabled)

        INDArray items = Nd4j.vstack(points);
        int randomPoint = MathUtils.randomNumberBetween(0, items.rows() - 1, Nd4j.getRandom());
        INDArray basePoint = points.get(randomPoint);
        ret.point = basePoint;
        ret.index = indices.get(randomPoint);
        INDArray distancesArr = Nd4j.create(items.rows(), 1);

        calcDistancesRelativeTo(items, basePoint, distancesArr);

        double medianDistance = distancesArr.medianNumber().doubleValue();

        ret.threshold = (float) medianDistance;

        List<INDArray> leftPoints = new ArrayList<>();
        List<Integer> leftIndices = new ArrayList<>();
        List<INDArray> rightPoints = new ArrayList<>();
        List<Integer> rightIndices = new ArrayList<>();

        // split the remaining points on the median distance to the vantage point
        for (int i = 0; i < distancesArr.length(); i++) {
            if (i == randomPoint)
                continue;

            if (distancesArr.getDouble(i) < medianDistance) {
                leftPoints.add(points.get(i));
                leftIndices.add(indices.get(i));
            } else {
                rightPoints.add(points.get(i));
                rightIndices.add(indices.get(i));
            }
        }

        if (workers > 1) {
            if (!leftPoints.isEmpty())
                ret.futureLeft = executorService.submit(new NodeBuilder(leftPoints, leftIndices));

            if (!rightPoints.isEmpty())
                ret.futureRight = executorService.submit(new NodeBuilder(rightPoints, rightIndices));
        } else {
            if (!leftPoints.isEmpty())
                ret.left = buildFromPoints(leftPoints, leftIndices);

            if (!rightPoints.isEmpty())
                ret.right = buildFromPoints(rightPoints, rightIndices);
        }

        return ret;
    }

    private Node buildFromPoints(INDArray items) {
        if (executorService == null && items == this.items && workers > 1) {
            final val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread();

            executorService = Executors.newFixedThreadPool(workers, new ThreadFactory() {
                @Override
                public Thread newThread(final Runnable r) {
                    Thread t = new Thread(new Runnable() {

                        @Override
                        public void run() {
                            // pin worker threads to the device of the building thread
                            Nd4j.getAffinityManager().unsafeSetDevice(deviceId);
                            r.run();
                        }
                    });

                    t.setDaemon(true);
                    t.setName("VPTree thread");

                    return t;
                }
            });
        }

        final Node ret = new Node(0, 0);
        size.incrementAndGet();

        // (workspace-based allocation was experimented with here and is intentionally disabled)

        int randomPoint = MathUtils.randomNumberBetween(0, items.rows() - 1, Nd4j.getRandom());
        INDArray basePoint = items.getRow(randomPoint, true);
        INDArray distancesArr = Nd4j.create(items.rows(), 1);
        ret.point = basePoint;
        ret.index = randomPoint;

        calcDistancesRelativeTo(items, basePoint, distancesArr);

        double medianDistance = distancesArr.medianNumber().doubleValue();

        ret.threshold = (float) medianDistance;

        List<INDArray> leftPoints = new ArrayList<>();
        List<Integer> leftIndices = new ArrayList<>();
        List<INDArray> rightPoints = new ArrayList<>();
        List<Integer> rightIndices = new ArrayList<>();

        for (int i = 0; i < distancesArr.length(); i++) {
            if (i == randomPoint)
                continue;

            if (distancesArr.getDouble(i) < medianDistance) {
                leftPoints.add(items.getRow(i, true));
                leftIndices.add(i);
            } else {
                rightPoints.add(items.getRow(i, true));
                rightIndices.add(i);
            }
        }

        if (!leftPoints.isEmpty())
            ret.left = buildFromPoints(leftPoints, leftIndices);

        if (!rightPoints.isEmpty())
            ret.right = buildFromPoints(rightPoints, rightIndices);

        // when built in parallel, resolve the child futures before returning
        if (ret.left != null)
            ret.left.fetchFutures();

        if (ret.right != null)
            ret.right.fetchFutures();

        if (executorService != null)
            executorService.shutdown();

        return ret;
    }

    public void search(@NonNull INDArray target, int k, List<DataPoint> results, List<Double> distances) {
        search(target, k, results, distances, true);
    }

    public void search(@NonNull INDArray target, int k, List<DataPoint> results, List<Double> distances,
                    boolean filterEqual) {
        search(target, k, results, distances, filterEqual, false);
    }

    /**
     * Find the {@code k} nearest neighbours of {@code target}.
     *
     * @param target the query point, shape [1, columns]
     * @param k the number of neighbours to return
     * @param results output list of neighbouring points, nearest first
     * @param distances output list of the corresponding distances
     * @param filterEqual whether to drop an exact (zero-distance) match from the results
     * @param dropEdge whether to trim the result list down to k entries
     */
    public void search(@NonNull INDArray target, int k, List<DataPoint> results, List<Double> distances,
                    boolean filterEqual, boolean dropEdge) {
        if (items != null)
            if (!target.isVectorOrScalar() || target.columns() != items.columns() || target.rows() > 1)
                throw new ND4JIllegalStateException("Target for search should have shape of [" + 1 + ", "
                                + items.columns() + "] but got " + Arrays.toString(target.shape()) + " instead");

        k = Math.min(k, items.rows());
        results.clear();
        distances.clear();

        PriorityQueue<HeapObject> pq = new PriorityQueue<>(items.rows(), new HeapObjectComparator());

        // over-collect by one or two entries so an exact match can be filtered out below
        search(root, target, k + (filterEqual ? 2 : 1), pq, Double.MAX_VALUE);

        while (!pq.isEmpty()) {
            HeapObject ho = pq.peek();
            results.add(new DataPoint(ho.getIndex(), ho.getPoint()));
            distances.add(ho.getDistance());
            pq.poll();
        }

        // the queue yields farthest-first, so reverse into nearest-first order
        Collections.reverse(results);
        Collections.reverse(distances);

        if (dropEdge || results.size() > k) {
            if (filterEqual && distances.get(0) == 0.0) {
                results.remove(0);
                distances.remove(0);
            }

            while (results.size() > k) {
                results.remove(results.size() - 1);
                distances.remove(distances.size() - 1);
            }
        }
    }

    /**
     * Recursive tree search: descend into a child only if the ball of radius
     * tau around the target can still intersect that child's region.
     *
     * @param node the subtree root
     * @param target the query point
     * @param k the number of neighbours to collect
     * @param pq max-heap of the best candidates found so far
     * @param cTau the current search radius
     */
    public void search(Node node, INDArray target, int k, PriorityQueue<HeapObject> pq, double cTau) {

        if (node == null)
            return;

        double tau = cTau;

        INDArray get = node.getPoint();
        double distance = distance(get, target);
        if (distance < tau) {
            if (pq.size() == k)
                pq.poll();

            pq.add(new HeapObject(node.getIndex(), node.getPoint(), distance));
            if (pq.size() == k)
                tau = pq.peek().getDistance();
        }

        Node left = node.getLeft();
        Node right = node.getRight();

        if (left == null && right == null)
            return;

        if (distance < node.getThreshold()) {
            if (distance - tau < node.getThreshold()) { // if there can still be neighbors inside the ball, recursively search left child first
                search(left, target, k, pq, tau);
            }

            if (distance + tau >= node.getThreshold()) { // if there can still be neighbors outside the ball, recursively search right child
                search(right, target, k, pq, tau);
            }
        } else {
            if (distance + tau >= node.getThreshold()) { // if there can still be neighbors outside the ball, recursively search right child first
                search(right, target, k, pq, tau);
            }

            if (distance - tau < node.getThreshold()) { // if there can still be neighbors inside the ball, recursively search left child
                search(left, target, k, pq, tau);
            }
        }
    }

    protected class HeapObjectComparator implements Comparator<HeapObject> {

        @Override
        public int compare(HeapObject o1, HeapObject o2) {
            // reversed comparison: the head of the queue is the farthest candidate
            return Double.compare(o2.getDistance(), o1.getDistance());
        }
    }

    @Data
    public static class Node implements Serializable {
        private static final long serialVersionUID = 2L;

        private int index;
        private float threshold;
        private Node left, right;
        private INDArray point;
        protected transient Future<Node> futureLeft;
        protected transient Future<Node> futureRight;

        public Node(int index, float threshold) {
            this.index = index;
            this.threshold = threshold;
        }

        /** Resolve any pending child-building futures, recursively. */
        public void fetchFutures() {
            try {
                if (futureLeft != null) {
                    left = futureLeft.get();
                }

                if (futureRight != null) {
                    right = futureRight.get();
                }

                if (left != null)
                    left.fetchFutures();

                if (right != null)
                    right.fetchFutures();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
    }
}
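The recursive search above prunes with the standard vantage-point rule: with d the distance from the query to the vantage point, mu the node's median threshold, and tau the current search radius, the inside (left) child can be skipped when d - tau >= mu, and the outside (right) child when d + tau < mu. A minimal usage sketch follows; the data is random and the Nd4j.rand call signature is assumed.

import org.deeplearning4j.clustering.sptree.DataPoint;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ArrayList;
import java.util.List;

public class VPTreeSketch {
    public static void main(String[] args) {
        // index 100 random 8-dimensional points under Euclidean distance
        INDArray data = Nd4j.rand(DataType.FLOAT, 100, 8);
        VPTree tree = new VPTree(data, "euclidean", false);

        // query the 5 nearest neighbours of a random target
        INDArray target = Nd4j.rand(DataType.FLOAT, 1, 8);
        List<DataPoint> results = new ArrayList<>();
        List<Double> distances = new ArrayList<>();
        tree.search(target, 5, results, distances);

        System.out.println("nearest distances: " + distances);
    }
}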
@@ -1,79 +0,0 @@
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.vptree;

import lombok.Getter;
import org.deeplearning4j.clustering.sptree.DataPoint;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ArrayList;
import java.util.List;

/**
 * Brute-force k-nearest-neighbour search over a {@link VPTree}'s items:
 * every item is ranked by distance to the target, so the result list is
 * always filled with exactly k entries.
 */
public class VPTreeFillSearch {
    private VPTree vpTree;
    private int k;
    @Getter
    private List<DataPoint> results;
    @Getter
    private List<Double> distances;
    private INDArray target;

    public VPTreeFillSearch(VPTree vpTree, int k, INDArray target) {
        this.vpTree = vpTree;
        this.k = k;
        this.target = target;
    }

    public void search() {
        results = new ArrayList<>();
        distances = new ArrayList<>();

        // rank all items by distance to the target and keep the first k
        INDArray distancesArr = Nd4j.create(vpTree.getItems().rows(), 1);
        vpTree.calcDistancesRelativeTo(target, distancesArr);
        INDArray[] sortWithIndices = Nd4j.sortWithIndices(distancesArr, 0, !vpTree.isInvert());
        results.clear();
        distances.clear();
        if (vpTree.getItems().isVector()) {
            for (int i = 0; i < k; i++) {
                int idx = sortWithIndices[0].getInt(i);
                results.add(new DataPoint(idx, Nd4j.scalar(vpTree.getItems().getDouble(idx))));
                // sortWithIndices[1] holds the distances already sorted, so index by position i
                distances.add(sortWithIndices[1].getDouble(i));
            }
        } else {
            for (int i = 0; i < k; i++) {
                int idx = sortWithIndices[0].getInt(i);
                results.add(new DataPoint(idx, vpTree.getItems().getRow(idx)));
                distances.add(sortWithIndices[1].getDouble(i));
            }
        }
    }
}
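Usage mirrors the tree search but always yields exactly k entries. A short sketch, reusing the hypothetical tree and target variables from the VPTree example above:

// brute-force fallback: rank every indexed item against the target
VPTreeFillSearch fillSearch = new VPTreeFillSearch(tree, 5, target);
fillSearch.search();
List<DataPoint> neighbours = fillSearch.getResults();
List<Double> neighbourDistances = fillSearch.getDistances();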
@@ -1,21 +0,0 @@
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.vptree;
@@ -1,46 +0,0 @@
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.cluster;

import org.junit.Assert;
import org.junit.Test;
import org.nd4j.linalg.factory.Nd4j;

import java.util.ArrayList;
import java.util.List;

public class ClusterSetTest {
    @Test
    public void testGetMostPopulatedClusters() {
        ClusterSet clusterSet = new ClusterSet(false);
        List<Cluster> clusters = new ArrayList<>();
        // five clusters holding 1..5 points respectively
        for (int i = 0; i < 5; i++) {
            Cluster cluster = new Cluster();
            cluster.setPoints(Point.toPoints(Nd4j.randn(i + 1, 5)));
            clusters.add(cluster);
        }
        clusterSet.setClusters(clusters);
        List<Cluster> mostPopulatedClusters = clusterSet.getMostPopulatedClusters(5);
        // expect descending population: 5, 4, 3, 2, 1
        for (int i = 0; i < 5; i++) {
            Assert.assertEquals(5 - i, mostPopulatedClusters.get(i).getPoints().size());
        }
    }
}
@@ -1,422 +0,0 @@
/* SPDX-License-Identifier: Apache-2.0 */

package org.deeplearning4j.clustering.kdtree;

import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.joda.time.Duration;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;
import org.nd4j.shade.guava.base.Stopwatch;
import org.nd4j.shade.guava.primitives.Doubles;
import org.nd4j.shade.guava.primitives.Floats;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static java.util.concurrent.TimeUnit.SECONDS;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

public class KDTreeTest extends BaseDL4JTest {

    @Override
    public long getTimeoutMilliseconds() {
        return 120000L;
    }

    private KDTree kdTree;

    @BeforeClass
    public static void beforeClass() {
        Nd4j.setDataType(DataType.FLOAT);
    }

    @Before
    public void setUp() {
        // the six 2-d points queried by the kNN tests below
        kdTree = new KDTree(2);
        float[] data = new float[]{7, 2};
        kdTree.insert(Nd4j.createFromArray(data));
        data = new float[]{5, 4};
        kdTree.insert(Nd4j.createFromArray(data));
        data = new float[]{2, 3};
        kdTree.insert(Nd4j.createFromArray(data));
        data = new float[]{4, 7};
        kdTree.insert(Nd4j.createFromArray(data));
        data = new float[]{9, 6};
        kdTree.insert(Nd4j.createFromArray(data));
        data = new float[]{8, 1};
        kdTree.insert(Nd4j.createFromArray(data));
    }

    @Test
    public void testTree() {
        KDTree tree = new KDTree(2);
        INDArray half = Nd4j.create(new double[] {0.5, 0.5}, new long[]{1, 2}).castTo(DataType.FLOAT);
        INDArray one = Nd4j.create(new double[] {1, 1}, new long[]{1, 2}).castTo(DataType.FLOAT);
        tree.insert(half);
        tree.insert(one);
        Pair<Double, INDArray> pair = tree.nn(Nd4j.create(new double[] {0.5, 0.5}, new long[]{1, 2}).castTo(DataType.FLOAT));
        assertEquals(half, pair.getValue());
    }

    @Test
    public void testInsert() {
        int elements = 10;
        List<Double> digits = Arrays.asList(1.0, 0.0, 2.0, 3.0);

        KDTree kdTree = new KDTree(digits.size());
        List<List<Double>> lists = new ArrayList<>();
        for (int i = 0; i < elements; i++) {
            List<Double> thisList = new ArrayList<>(digits.size());
            for (int k = 0; k < digits.size(); k++) {
                thisList.add(digits.get(k) + i);
            }
            lists.add(thisList);
        }

        for (int i = 0; i < elements; i++) {
            double[] features = Doubles.toArray(lists.get(i));
            INDArray ind = Nd4j.create(features, new long[]{1, features.length}, DataType.FLOAT);
            kdTree.insert(ind);
            assertEquals(i + 1, kdTree.size());
        }
    }

    @Test
    public void testDelete() {
        int elements = 10;
        List<Double> digits = Arrays.asList(1.0, 0.0, 2.0, 3.0);

        KDTree kdTree = new KDTree(digits.size());
        List<List<Double>> lists = new ArrayList<>();
        for (int i = 0; i < elements; i++) {
            List<Double> thisList = new ArrayList<>(digits.size());
            for (int k = 0; k < digits.size(); k++) {
                thisList.add(digits.get(k) + i);
            }
            lists.add(thisList);
        }

        INDArray toDelete = Nd4j.empty(DataType.DOUBLE),
                        leafToDelete = Nd4j.empty(DataType.DOUBLE);
        for (int i = 0; i < elements; i++) {
            double[] features = Doubles.toArray(lists.get(i));
            INDArray ind = Nd4j.create(features, new long[]{1, features.length}, DataType.FLOAT);
            if (i == 1)
                toDelete = ind;
            if (i == elements - 1) {
                leafToDelete = ind;
            }
            kdTree.insert(ind);
            assertEquals(i + 1, kdTree.size());
        }

        kdTree.delete(toDelete);
        assertEquals(9, kdTree.size());
        kdTree.delete(leafToDelete);
        assertEquals(8, kdTree.size());
    }

    @Test
    public void testNN() {
        int n = 10;

        // make a KD-tree of dimension {#n}
        KDTree kdTree = new KDTree(n);
        for (int i = -1; i < n; i++) {
            // insert a unit vector along each dimension;
            // i = -1 ensures the origin is in the tree
            List<Double> vec = new ArrayList<>(n);
            for (int k = 0; k < n; k++) {
                vec.add((k == i) ? 1.0 : 0.0);
            }
            INDArray indVec = Nd4j.create(Doubles.toArray(vec), new long[]{1, vec.size()}, DataType.FLOAT);
            kdTree.insert(indVec);
        }
        Random rand = new Random();

        // random point in the hypercube
        List<Double> pt = new ArrayList<>(n);
        for (int k = 0; k < n; k++) {
            pt.add(rand.nextDouble());
        }
        Pair<Double, INDArray> result = kdTree.nn(Nd4j.create(Doubles.toArray(pt), new long[]{1, pt.size()}, DataType.FLOAT));

        // always true for points in the unit hypercube
        assertTrue(result.getKey() < Double.MAX_VALUE);
    }

    @Test
    public void testKNN() {
        int dimensions = 512;
        int vectorsNo = isIntegrationTests() ? 50000 : 1000;
        // make a KD-tree of dimension {#dimensions}
        Stopwatch stopwatch = Stopwatch.createStarted();
        KDTree kdTree = new KDTree(dimensions);
        for (int i = -1; i < vectorsNo; i++) {
            // insert random vectors
            INDArray indVec = Nd4j.rand(DataType.FLOAT, 1, dimensions);
            kdTree.insert(indVec);
        }
        stopwatch.stop();
        System.out.println("Time elapsed for " + kdTree.size() + " nodes construction is " + stopwatch.elapsed(SECONDS));

        Random rand = new Random();
        // random point in the hypercube
        List<Double> pt = new ArrayList<>(dimensions);
        for (int k = 0; k < dimensions; k++) {
            pt.add(rand.nextFloat() * 10.0);
        }
        stopwatch.reset();
        stopwatch.start();
        List<Pair<Float, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Floats.toArray(pt))), 20.0f);
        stopwatch.stop();
        System.out.println("Time elapsed for Search is " + stopwatch.elapsed(MILLISECONDS));
    }

    @Test
    public void testKNN_Simple() {
        int n = 2;
        KDTree kdTree = new KDTree(n);

        float[] data = new float[]{3, 3};
        kdTree.insert(Nd4j.createFromArray(data));
        data = new float[]{1, 1};
        kdTree.insert(Nd4j.createFromArray(data));
        data = new float[]{2, 2};
        kdTree.insert(Nd4j.createFromArray(data));

        data = new float[]{0, 0};
        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 4.5f);

        assertEquals(1.0, result.get(0).getSecond().getDouble(0), 1e-5);
        assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5);

        assertEquals(2.0, result.get(1).getSecond().getDouble(0), 1e-5);
        assertEquals(2.0, result.get(1).getSecond().getDouble(1), 1e-5);

        assertEquals(3.0, result.get(2).getSecond().getDouble(0), 1e-5);
        assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5);
    }

    @Test
    public void testKNN_1() {

        assertEquals(6, kdTree.size());

        float[] data = new float[]{8, 1};
        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
        assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5);
        assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5);
        assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5);
        assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5);
        assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5);
        assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5);
        assertEquals(9.0, result.get(3).getSecond().getFloat(0), 1e-5);
        assertEquals(6.0, result.get(3).getSecond().getFloat(1), 1e-5);
        assertEquals(2.0, result.get(4).getSecond().getFloat(0), 1e-5);
        assertEquals(3.0, result.get(4).getSecond().getFloat(1), 1e-5);
        assertEquals(4.0, result.get(5).getSecond().getFloat(0), 1e-5);
        assertEquals(7.0, result.get(5).getSecond().getFloat(1), 1e-5);
    }

    @Test
    public void testKNN_2() {
        float[] data = new float[]{8, 1};
        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
        assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5);
        assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5);
        assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5);
        assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5);
        assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5);
        assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5);
    }

    @Test
    public void testKNN_3() {

        float[] data = new float[]{2, 3};
        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
        assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
        assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
        assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
        assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
        assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
        assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
        assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5);
        assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5);
        assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5);
        assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5);
        assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5);
        assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5);
    }

    @Test
    public void testKNN_4() {
        float[] data = new float[]{2, 3};
        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
        assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
        assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
        assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
        assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
        assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
        assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
    }

    @Test
    public void testKNN_5() {
        float[] data = new float[]{2, 3};
        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f);
        assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5);
        assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5);
        assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5);
        assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5);
        assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5);
        assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5);
        assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5);
        assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5);
        assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5);
        assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5);
        assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5);
        assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5);
    }

    @Test
    public void test_KNN_6() {
        float[] data = new float[]{4, 6};
        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f);
        assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
        assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
        assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
        assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5);
        assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5);
        assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5);
        assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5);
        assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5);
        assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5);
        assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5);
        assertEquals(8.0, result.get(5).getSecond().getDouble(0), 1e-5);
        assertEquals(1.0, result.get(5).getSecond().getDouble(1), 1e-5);
    }

    @Test
    public void test_KNN_7() {
        float[] data = new float[]{4, 6};
        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f);
        assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
        assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
        assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
        assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5);
        assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5);
        assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5);
        assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5);
        assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5);
        assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5);
        assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5);
    }

    @Test
    public void test_KNN_8() {
        float[] data = new float[]{4, 6};
        List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f);
        assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5);
        assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5);
        assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5);
        assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5);
        assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5);
        assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5);
        assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5);
        assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5);
        assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5);
        assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5);
        assertEquals(8.0, result.get(5).getSecond().getDouble(0), 1e-5);
        assertEquals(1.0, result.get(5).getSecond().getDouble(1), 1e-5);
    }

    @Test
    public void testNoDuplicates() {
        int N = 100;
        KDTree bigTree = new KDTree(2);

        List<INDArray> points = new ArrayList<>();
        for (int i = 0; i < N; ++i) {
            double[] data = new double[]{i, i};
            points.add(Nd4j.createFromArray(data));
        }

        for (int i = 0; i < N; ++i) {
            bigTree.insert(points.get(i));
        }

        assertEquals(N, bigTree.size());

        INDArray node = Nd4j.empty(DataType.DOUBLE);
        for (int i = 0; i < N; ++i) {
            node = bigTree.delete(node.isEmpty() ? points.get(i) : node);
        }

        assertEquals(0, bigTree.size());
    }

    @Ignore
    @Test
    public void performanceTest() {
        int n = 2;
        int num = 100000;
        // make a KD-tree of dimension {#n}
        long start = System.currentTimeMillis();
        KDTree kdTree = new KDTree(n);
        INDArray inputArray = Nd4j.randn(DataType.DOUBLE, num, n);
        for (int i = 0; i < num; ++i) {
            kdTree.insert(inputArray.getRow(i));
        }

        long end = System.currentTimeMillis();
        Duration duration = new Duration(start, end);
        System.out.println("Elapsed time for tree construction " + duration.getStandardSeconds() + " " + duration.getMillis());

        List<Float> pt = new ArrayList<>(n);
        for (int k = 0; k < n; k++) {
            pt.add((float) (num / 2));
        }
        start = System.currentTimeMillis();
        List<Pair<Float, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0f);
        end = System.currentTimeMillis();
        duration = new Duration(start, end);
        System.out.println("Elapsed time for tree search " + duration.getStandardSeconds() + " " + duration.getMillis());
        for (val pair : list) {
            System.out.println(pair.getFirst() + " " + pair.getSecond());
        }
    }
}
Some files were not shown because too many files have changed in this diff