Remove more unused modules
parent fa8537f0c7
commit ee06fdd16f

@@ -1,64 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  ~ /* ******************************************************************************
  ~  *
  ~  *
  ~  * This program and the accompanying materials are made available under the
  ~  * terms of the Apache License, Version 2.0 which is available at
  ~  * https://www.apache.org/licenses/LICENSE-2.0.
  ~  *
  ~  *  See the NOTICE file distributed with this work for additional
  ~  *  information regarding copyright ownership.
  ~  * Unless required by applicable law or agreed to in writing, software
  ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~  * License for the specific language governing permissions and limitations
  ~  * under the License.
  ~  *
  ~  * SPDX-License-Identifier: Apache-2.0
  ~  ******************************************************************************/
  -->

<project xmlns="http://maven.apache.org/POM/4.0.0"
    xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
    xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-spark-inference-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>datavec-spark-inference-client</artifactId>

    <name>datavec-spark-inference-client</name>

    <dependencies>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-spark-inference-server_2.11</artifactId>
            <version>1.0.0-SNAPSHOT</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-spark-inference-model</artifactId>
            <version>${project.parent.version}</version>
        </dependency>
        <dependency>
            <groupId>com.mashape.unirest</groupId>
            <artifactId>unirest-java</artifactId>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
        </profile>
    </profiles>
</project>
@@ -1,292 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.spark.inference.client;

import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import com.mashape.unirest.http.exceptions.UnirestException;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.transform.TransformProcess;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.model.*;
import org.datavec.spark.inference.model.service.DataVecTransformService;
import org.nd4j.shade.jackson.core.JsonProcessingException;

import java.io.IOException;

@AllArgsConstructor
@Slf4j
public class DataVecTransformClient implements DataVecTransformService {
    private String url;

    static {
        // Configure Unirest's shared object mapper once, delegating to Jackson
        Unirest.setObjectMapper(new ObjectMapper() {
            private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
                    new org.nd4j.shade.jackson.databind.ObjectMapper();

            public <T> T readValue(String value, Class<T> valueType) {
                try {
                    return jacksonObjectMapper.readValue(value, valueType);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }

            public String writeValue(Object value) {
                try {
                    return jacksonObjectMapper.writeValueAsString(value);
                } catch (JsonProcessingException e) {
                    throw new RuntimeException(e);
                }
            }
        });
    }

    /**
     * @param transformProcess the CSV transform process to set on the server
     */
    @Override
    public void setCSVTransformProcess(TransformProcess transformProcess) {
        try {
            String s = transformProcess.toJson();
            Unirest.post(url + "/transformprocess").header("accept", "application/json")
                    .header("Content-Type", "application/json").body(s).asJson();
        } catch (UnirestException e) {
            log.error("Error in setCSVTransformProcess()", e);
        }
    }

    @Override
    public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    /**
     * @return the current CSV transform process, or null if the request fails
     */
    @Override
    public TransformProcess getCSVTransformProcess() {
        try {
            String s = Unirest.get(url + "/transformprocess").header("accept", "application/json")
                    .header("Content-Type", "application/json").asString().getBody();
            return TransformProcess.fromJson(s);
        } catch (UnirestException e) {
            log.error("Error in getCSVTransformProcess()", e);
        }

        return null;
    }

    @Override
    public ImageTransformProcess getImageTransformProcess() {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    /**
     * @param transform the single record to transform
     * @return the transformed record, or null if the request fails
     */
    @Override
    public SingleCSVRecord transformIncremental(SingleCSVRecord transform) {
        try {
            SingleCSVRecord singleCsvRecord = Unirest.post(url + "/transformincremental")
                    .header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .body(transform).asObject(SingleCSVRecord.class).getBody();
            return singleCsvRecord;
        } catch (UnirestException e) {
            log.error("Error in transformIncremental(SingleCSVRecord)", e);
        }
        return null;
    }

    /**
     * @param batchCSVRecord the sequence batch to transform
     * @return the transformed batch, or null if the request fails
     */
    @Override
    public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
        try {
            SequenceBatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform").header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "TRUE")
                    .body(batchCSVRecord)
                    .asObject(SequenceBatchCSVRecord.class)
                    .getBody();
            return batchCSVRecord1;
        } catch (UnirestException e) {
            log.error("Error in transform(SequenceBatchCSVRecord)", e);
        }

        return null;
    }

    /**
     * @param batchCSVRecord the batch to transform
     * @return the transformed batch, or null if the request fails
     */
    @Override
    public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) {
        try {
            BatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform").header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "FALSE")
                    .body(batchCSVRecord)
                    .asObject(BatchCSVRecord.class)
                    .getBody();
            return batchCSVRecord1;
        } catch (UnirestException e) {
            log.error("Error in transform(BatchCSVRecord)", e);
        }

        return null;
    }

    /**
     * @param batchCSVRecord the batch to convert
     * @return the resulting base64-encoded ndarray, or null if the request fails
     */
    @Override
    public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) {
        try {
            Base64NDArrayBody batchArray1 = Unirest.post(url + "/transformarray").header("accept", "application/json")
                    .header("Content-Type", "application/json").body(batchCSVRecord)
                    .asObject(Base64NDArrayBody.class).getBody();
            return batchArray1;
        } catch (UnirestException e) {
            log.error("Error in transformArray(BatchCSVRecord)", e);
        }

        return null;
    }

    /**
     * @param singleCsvRecord the single record to convert
     * @return the resulting base64-encoded ndarray, or null if the request fails
     */
    @Override
    public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) {
        try {
            Base64NDArrayBody array = Unirest.post(url + "/transformincrementalarray")
                    .header("accept", "application/json").header("Content-Type", "application/json")
                    .body(singleCsvRecord).asObject(Base64NDArrayBody.class).getBody();
            return array;
        } catch (UnirestException e) {
            log.error("Error in transformArrayIncremental(SingleCSVRecord)", e);
        }

        return null;
    }

    @Override
    public Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    @Override
    public Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    /**
     * @param singleCsvRecord the batch to convert as a sequence
     * @return the resulting base64-encoded ndarray, or null if the request fails
     */
    @Override
    public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) {
        try {
            Base64NDArrayBody array = Unirest.post(url + "/transformincrementalarray")
                    .header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "true")
                    .body(singleCsvRecord).asObject(Base64NDArrayBody.class).getBody();
            return array;
        } catch (UnirestException e) {
            log.error("Error in transformSequenceArrayIncremental", e);
        }

        return null;
    }

    /**
     * @param batchCSVRecord the sequence batch to convert
     * @return the resulting base64-encoded ndarray, or null if the request fails
     */
    @Override
    public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) {
        try {
            Base64NDArrayBody batchArray1 = Unirest.post(url + "/transformarray").header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "true")
                    .body(batchCSVRecord)
                    .asObject(Base64NDArrayBody.class).getBody();
            return batchArray1;
        } catch (UnirestException e) {
            log.error("Error in transformSequenceArray", e);
        }

        return null;
    }

    /**
     * @param batchCSVRecord the sequence batch to transform
     * @return the transformed sequence batch, or null if the request fails
     */
    @Override
    public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) {
        try {
            SequenceBatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform")
                    .header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "true")
                    .body(batchCSVRecord)
                    .asObject(SequenceBatchCSVRecord.class).getBody();
            return batchCSVRecord1;
        } catch (UnirestException e) {
            log.error("Error in transformSequence", e);
        }

        return null;
    }

    /**
     * @param transform the batch to transform into a sequence
     * @return the transformed sequence batch, or null if the request fails
     */
    @Override
    public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) {
        try {
            SequenceBatchCSVRecord singleCsvRecord = Unirest.post(url + "/transformincremental")
                    .header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "true")
                    .body(transform).asObject(SequenceBatchCSVRecord.class).getBody();
            return singleCsvRecord;
        } catch (UnirestException e) {
            log.error("Error in transformSequenceIncremental", e);
        }
        return null;
    }
}
@@ -1,45 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
package org.datavec.transform.client;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;

@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {

    @Override
    protected Set<Class<?>> getExclusions() {
        //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
        return new HashSet<>();
    }

    @Override
    protected String getPackageName() {
        return "org.datavec.transform.client";
    }

    @Override
    protected Class<?> getBaseClass() {
        return BaseND4JTest.class;
    }
}
@@ -1,139 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.transform.client;

import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.spark.inference.server.CSVSparkTransformServer;
import org.datavec.spark.inference.client.DataVecTransformClient;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;

import java.io.File;
import java.io.IOException;
import java.net.ServerSocket;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;

import static org.junit.Assert.assertEquals;
import static org.junit.Assume.assumeNotNull;

public class DataVecTransformClientTest {
    private static CSVSparkTransformServer server;
    private static int port = getAvailablePort();
    private static DataVecTransformClient client;
    private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
    private static TransformProcess transformProcess =
            new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build();
    private static File fileSave = new File(UUID.randomUUID().toString() + ".json");

    @BeforeClass
    public static void beforeClass() throws Exception {
        FileUtils.write(fileSave, transformProcess.toJson());
        fileSave.deleteOnExit();
        server = new CSVSparkTransformServer();
        server.runMain(new String[] {"-dp", String.valueOf(port)});

        client = new DataVecTransformClient("http://localhost:" + port);
        client.setCSVTransformProcess(transformProcess);
    }

    @AfterClass
    public static void afterClass() throws Exception {
        server.stop();
    }

    @Test
    public void testSequenceClient() {
        SequenceBatchCSVRecord sequenceBatchCSVRecord = new SequenceBatchCSVRecord();
        SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"});

        BatchCSVRecord batchCSVRecord = new BatchCSVRecord(Arrays.asList(singleCsvRecord, singleCsvRecord));
        List<BatchCSVRecord> batchCSVRecordList = new ArrayList<>();
        for (int i = 0; i < 5; i++) {
            batchCSVRecordList.add(batchCSVRecord);
        }

        sequenceBatchCSVRecord.add(batchCSVRecordList);

        SequenceBatchCSVRecord sequenceBatchCSVRecord1 = client.transformSequence(sequenceBatchCSVRecord);
        assumeNotNull(sequenceBatchCSVRecord1);

        Base64NDArrayBody array = client.transformSequenceArray(sequenceBatchCSVRecord);
        assumeNotNull(array);

        Base64NDArrayBody incrementalBody = client.transformSequenceArrayIncremental(batchCSVRecord);
        assumeNotNull(incrementalBody);

        Base64NDArrayBody incrementalSequenceBody = client.transformSequenceArrayIncremental(batchCSVRecord);
        assumeNotNull(incrementalSequenceBody);
    }

    @Test
    public void testRecord() throws Exception {
        SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"});
        SingleCSVRecord transformed = client.transformIncremental(singleCsvRecord);
        assertEquals(singleCsvRecord.getValues().size(), transformed.getValues().size());
        Base64NDArrayBody body = client.transformArrayIncremental(singleCsvRecord);
        INDArray arr = Nd4jBase64.fromBase64(body.getNdarray());
        assumeNotNull(arr);
    }

    @Test
    public void testBatchRecord() throws Exception {
        SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"});

        BatchCSVRecord batchCSVRecord = new BatchCSVRecord(Arrays.asList(singleCsvRecord, singleCsvRecord));
        BatchCSVRecord batchCSVRecord1 = client.transform(batchCSVRecord);
        assertEquals(batchCSVRecord.getRecords().size(), batchCSVRecord1.getRecords().size());

        Base64NDArrayBody body = client.transformArray(batchCSVRecord);
        INDArray arr = Nd4jBase64.fromBase64(body.getNdarray());
        assumeNotNull(arr);
    }

    public static int getAvailablePort() {
        try {
            ServerSocket socket = new ServerSocket(0);
            try {
                return socket.getLocalPort();
            } finally {
                socket.close();
            }
        } catch (IOException e) {
            throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e);
        }
    }

}
@@ -1,6 +0,0 @@
play.modules.enabled += com.lightbend.lagom.discovery.zookeeper.ZooKeeperServiceLocatorModule
play.modules.enabled += io.skymind.skil.service.PredictionModule
play.crypto.secret = as8dufasdfuasdfjkasdkfalksjfk
play.server.pidfile.path=/tmp/RUNNING_PID

play.server.http.port = 9600
@@ -1,63 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  ~ /* ******************************************************************************
  ~  *
  ~  *
  ~  * This program and the accompanying materials are made available under the
  ~  * terms of the Apache License, Version 2.0 which is available at
  ~  * https://www.apache.org/licenses/LICENSE-2.0.
  ~  *
  ~  *  See the NOTICE file distributed with this work for additional
  ~  *  information regarding copyright ownership.
  ~  * Unless required by applicable law or agreed to in writing, software
  ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~  * License for the specific language governing permissions and limitations
  ~  * under the License.
  ~  *
  ~  * SPDX-License-Identifier: Apache-2.0
  ~  ******************************************************************************/
  -->

<project xmlns="http://maven.apache.org/POM/4.0.0"
    xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
    xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-spark-inference-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>datavec-spark-inference-model</artifactId>

    <name>datavec-spark-inference-model</name>

    <dependencies>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-api</artifactId>
            <version>${datavec.version}</version>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-data-image</artifactId>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-local</artifactId>
            <version>${project.version}</version>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
        </profile>
    </profiles>
</project>
@@ -1,286 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.spark.inference.model;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.Writable;
import org.datavec.arrow.ArrowConverter;
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;

import static org.datavec.arrow.ArrowConverter.*;
import static org.datavec.local.transforms.LocalTransformExecutor.execute;
import static org.datavec.local.transforms.LocalTransformExecutor.executeToSequence;

@AllArgsConstructor
@Slf4j
public class CSVSparkTransform {
    @Getter
    private TransformProcess transformProcess;
    private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);

    /**
     * Convert a raw record via the {@link TransformProcess}
     * to a base64-encoded ndarray
     * @param batch the record to convert
     * @return the base64-encoded ndarray
     * @throws IOException if base64 serialization fails
     */
    public Base64NDArrayBody toArray(BatchCSVRecord batch) throws IOException {
        List<List<Writable>> converted = execute(toArrowWritables(toArrowColumnsString(
                bufferAllocator,transformProcess.getInitialSchema(),
                batch.getRecordsAsString()),
                transformProcess.getInitialSchema()),transformProcess);

        ArrowWritableRecordBatch arrowRecordBatch = (ArrowWritableRecordBatch) converted;
        INDArray convert = ArrowConverter.toArray(arrowRecordBatch);
        return new Base64NDArrayBody(Nd4jBase64.base64String(convert));
    }

    /**
     * Convert a raw record via the {@link TransformProcess}
     * to a base64-encoded ndarray
     * @param record the record to convert
     * @return the base64-encoded ndarray
     * @throws IOException if base64 serialization fails
     */
    public Base64NDArrayBody toArray(SingleCSVRecord record) throws IOException {
        List<Writable> record2 = toArrowWritablesSingle(
                toArrowColumnsStringSingle(bufferAllocator,
                        transformProcess.getInitialSchema(),record.getValues()),
                transformProcess.getInitialSchema());
        List<Writable> finalRecord = execute(Arrays.asList(record2),transformProcess).get(0);
        INDArray convert = RecordConverter.toArray(DataType.DOUBLE, finalRecord);
        return new Base64NDArrayBody(Nd4jBase64.base64String(convert));
    }

    /**
     * Runs the transform process
     * @param batch the record to transform
     * @return the transformed record
     */
    public BatchCSVRecord transform(BatchCSVRecord batch) {
        BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
        List<List<Writable>> converted = execute(toArrowWritables(toArrowColumnsString(
                bufferAllocator,transformProcess.getInitialSchema(),
                batch.getRecordsAsString()),
                transformProcess.getInitialSchema()),transformProcess);
        int numCols = converted.get(0).size();
        for (int row = 0; row < converted.size(); row++) {
            String[] values = new String[numCols];
            for (int i = 0; i < values.length; i++)
                values[i] = converted.get(row).get(i).toString();
            batchCSVRecord.add(new SingleCSVRecord(values));
        }

        return batchCSVRecord;
    }

    /**
     * Runs the transform process
     * @param record the record to transform
     * @return the transformed record
     */
    public SingleCSVRecord transform(SingleCSVRecord record) {
        List<Writable> record2 = toArrowWritablesSingle(
                toArrowColumnsStringSingle(bufferAllocator,
                        transformProcess.getInitialSchema(),record.getValues()),
                transformProcess.getInitialSchema());
        List<Writable> finalRecord = execute(Arrays.asList(record2),transformProcess).get(0);
        String[] values = new String[finalRecord.size()];
        for (int i = 0; i < values.length; i++)
            values[i] = finalRecord.get(i).toString();
        return new SingleCSVRecord(values);
    }

    /**
     * Runs the transform process on a batch, returning a sequence
     * @param transform the batch to transform
     * @return the transformed sequence
     */
    public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) {
        // Sequence schema?
        List<List<List<Writable>>> converted = executeToSequence(
                toArrowWritables(toArrowColumnsStringTimeSeries(
                        bufferAllocator, transformProcess.getInitialSchema(),
                        Arrays.asList(transform.getRecordsAsString())),
                        transformProcess.getInitialSchema()), transformProcess);

        SequenceBatchCSVRecord batchCSVRecord = new SequenceBatchCSVRecord();
        for (int i = 0; i < converted.size(); i++) {
            BatchCSVRecord batchCSVRecord1 = BatchCSVRecord.fromWritables(converted.get(i));
            batchCSVRecord.add(Arrays.asList(batchCSVRecord1));
        }

        return batchCSVRecord;
    }

    /**
     * Runs the transform process on a sequence batch
     * @param batchCSVRecordSequence the sequence batch to transform
     * @return the transformed sequence batch
     */
    public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecordSequence) {
        List<List<List<String>>> recordsAsString = batchCSVRecordSequence.getRecordsAsString();
        boolean allSameLength = true;
        Integer length = null;
        for(List<List<String>> record : recordsAsString) {
            if(length == null) {
                length = record.size();
            }
            else if(record.size() != length) {
                allSameLength = false;
            }
        }

        if(allSameLength) {
            List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), recordsAsString);
            ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors,
                    transformProcess.getInitialSchema(),
                    recordsAsString.get(0).get(0).size());
            val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch,transformProcess);
            return SequenceBatchCSVRecord.fromWritables(transformed);
        }
        else {
            val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecordSequence.getRecordsAsString(),transformProcess.getInitialSchema()),transformProcess);
            return SequenceBatchCSVRecord.fromWritables(transformed);
        }
    }

    /**
     * TODO: optimize
     * @param batchCSVRecordSequence the sequence batch to convert
     * @return the base64-encoded ndarray for the sequence
     */
    public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecordSequence) {
        List<List<List<String>>> strings = batchCSVRecordSequence.getRecordsAsString();
        boolean allSameLength = true;
        Integer length = null;
        for(List<List<String>> record : strings) {
            if(length == null) {
                length = record.size();
            }
            else if(record.size() != length) {
                allSameLength = false;
            }
        }

        if(allSameLength) {
            List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), strings);
            ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors,transformProcess.getInitialSchema(),strings.get(0).get(0).size());
            val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch,transformProcess);
            INDArray arr = RecordConverter.toTensor(transformed).reshape(strings.size(),strings.get(0).get(0).size(),strings.get(0).size());
            try {
                return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
            } catch (IOException e) {
                throw new IllegalStateException(e);
            }
        }
        else {
            val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecordSequence.getRecordsAsString(),transformProcess.getInitialSchema()),transformProcess);
            INDArray arr = RecordConverter.toTensor(transformed).reshape(strings.size(),strings.get(0).get(0).size(),strings.get(0).size());
            try {
                return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
            } catch (IOException e) {
                throw new IllegalStateException(e);
            }
        }
    }

    /**
     * Runs the transform process on a batch, returning a sequence ndarray
     * @param singleCsvRecord the batch to transform
     * @return the base64-encoded ndarray, or null if serialization fails
     */
    public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) {
        List<List<List<Writable>>> converted = executeToSequence(toArrowWritables(toArrowColumnsString(
                bufferAllocator,transformProcess.getInitialSchema(),
                singleCsvRecord.getRecordsAsString()),
                transformProcess.getInitialSchema()),transformProcess);
        ArrowWritableRecordTimeSeriesBatch arrowWritableRecordBatch = (ArrowWritableRecordTimeSeriesBatch) converted;
        INDArray arr = RecordConverter.toTensor(arrowWritableRecordBatch);
        try {
            return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
        } catch (IOException e) {
            log.error("Error in transformSequenceArrayIncremental", e);
        }

        return null;
    }

    public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
        List<List<List<String>>> strings = batchCSVRecord.getRecordsAsString();
        boolean allSameLength = true;
        Integer length = null;
        for(List<List<String>> record : strings) {
            if(length == null) {
                length = record.size();
            }
            else if(record.size() != length) {
                allSameLength = false;
            }
        }

        if(allSameLength) {
            List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), strings);
            ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors,transformProcess.getInitialSchema(),strings.get(0).get(0).size());
            val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch,transformProcess);
            return SequenceBatchCSVRecord.fromWritables(transformed);
        }
        else {
            val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecord.getRecordsAsString(),transformProcess.getInitialSchema()),transformProcess);
            return SequenceBatchCSVRecord.fromWritables(transformed);
        }
    }
}
@@ -1,64 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.spark.inference.model;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.datavec.image.data.ImageWritable;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchImageRecord;
import org.datavec.spark.inference.model.model.SingleImageRecord;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.serde.base64.Nd4jBase64;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

@AllArgsConstructor
public class ImageSparkTransform {
    @Getter
    private ImageTransformProcess imageTransformProcess;

    /**
     * Run the image transform process on a single image record
     * @param record the image record to convert
     * @return the result as a base64-encoded ndarray
     * @throws IOException if the image cannot be read or serialized
     */
    public Base64NDArrayBody toArray(SingleImageRecord record) throws IOException {
        ImageWritable record2 = imageTransformProcess.transformFileUriToInput(record.getUri());
        INDArray finalRecord = imageTransformProcess.executeArray(record2);

        return new Base64NDArrayBody(Nd4jBase64.base64String(finalRecord));
    }

    /**
     * Run the image transform process on a batch of image records,
     * concatenating the per-image arrays along dimension 0
     * @param batch the image records to convert
     * @return the result as a base64-encoded ndarray
     * @throws IOException if an image cannot be read or serialized
     */
    public Base64NDArrayBody toArray(BatchImageRecord batch) throws IOException {
        List<INDArray> records = new ArrayList<>();

        for (SingleImageRecord imgRecord : batch.getRecords()) {
            ImageWritable record2 = imageTransformProcess.transformFileUriToInput(imgRecord.getUri());
            INDArray finalRecord = imageTransformProcess.executeArray(record2);
            records.add(finalRecord);
        }

        INDArray array = Nd4j.concat(0, records.toArray(new INDArray[records.size()]));

        return new Base64NDArrayBody(Nd4jBase64.base64String(array));
    }

}
@@ -1,32 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.spark.inference.model.model;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class Base64NDArrayBody {
    private String ndarray;
}
@@ -1,104 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.spark.inference.model.model;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.dataset.DataSet;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;

@Data
@AllArgsConstructor
@Builder
@NoArgsConstructor
public class BatchCSVRecord implements Serializable {
    private List<SingleCSVRecord> records;

    /**
     * Get the records as a list of strings
     * (basically the underlying values for
     * {@link SingleCSVRecord})
     * @return the values of each record
     */
    public List<List<String>> getRecordsAsString() {
        if(records == null)
            records = new ArrayList<>();
        List<List<String>> ret = new ArrayList<>();
        for(SingleCSVRecord csvRecord : records) {
            ret.add(csvRecord.getValues());
        }
        return ret;
    }

    /**
     * Create a batch csv record
     * from a list of writables.
     * @param batch the writables to convert
     * @return the equivalent batch record
     */
    public static BatchCSVRecord fromWritables(List<List<Writable>> batch) {
        List<SingleCSVRecord> records = new ArrayList<>(batch.size());
        for(List<Writable> list : batch) {
            List<String> add = new ArrayList<>(list.size());
            for(Writable writable : list) {
                add.add(writable.toString());
            }
            records.add(new SingleCSVRecord(add));
        }

        return BatchCSVRecord.builder().records(records).build();
    }

    /**
     * Add a record
     * @param record the record to add
     */
    public void add(SingleCSVRecord record) {
        if (records == null)
            records = new ArrayList<>();
        records.add(record);
    }

    /**
     * Return a batch record based on a dataset
     * @param dataSet the dataset to get the batch record for
     * @return the batch record
     */
    public static BatchCSVRecord fromDataSet(DataSet dataSet) {
        BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
        for (int i = 0; i < dataSet.numExamples(); i++) {
            batchCSVRecord.add(SingleCSVRecord.fromRow(dataSet.get(i)));
        }

        return batchCSVRecord;
    }

}
@@ -1,50 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.spark.inference.model.model;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.net.URI;
import java.util.ArrayList;
import java.util.List;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class BatchImageRecord {
    private List<SingleImageRecord> records;

    /**
     * Add a record
     * @param record the record to add
     */
    public void add(SingleImageRecord record) {
        if (records == null)
            records = new ArrayList<>();
        records.add(record);
    }

    public void add(URI uri) {
        this.add(new SingleImageRecord(uri));
    }
}
@@ -1,106 +0,0 @@
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.spark.inference.model.model; | ||||
| 
 | ||||
| import lombok.AllArgsConstructor; | ||||
| import lombok.Builder; | ||||
| import lombok.Data; | ||||
| import lombok.NoArgsConstructor; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.nd4j.linalg.dataset.DataSet; | ||||
| import org.nd4j.linalg.dataset.MultiDataSet; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| import java.util.Collections; | ||||
| import java.util.List; | ||||
| 
 | ||||
| @Data | ||||
| @AllArgsConstructor | ||||
| @Builder | ||||
| @NoArgsConstructor | ||||
| public class SequenceBatchCSVRecord implements Serializable { | ||||
|     private List<List<BatchCSVRecord>> records; | ||||
| 
 | ||||
|     /** | ||||
|      * Add a record | ||||
|      * @param record | ||||
|      */ | ||||
|     public void add(List<BatchCSVRecord> record) { | ||||
|         if (records == null) | ||||
|             records = new ArrayList<>(); | ||||
|         records.add(record); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the records as a list of strings directly | ||||
|      * (this basically "unpacks" the objects) | ||||
|      * @return | ||||
|      */ | ||||
|     public List<List<List<String>>> getRecordsAsString() { | ||||
|         if(records == null) | ||||
|             Collections.emptyList(); | ||||
|         List<List<List<String>>> ret = new ArrayList<>(records.size()); | ||||
|         for(List<BatchCSVRecord> record : records) { | ||||
|             List<List<String>> add = new ArrayList<>(); | ||||
|             for(BatchCSVRecord batchCSVRecord : record) { | ||||
|                 for (SingleCSVRecord singleCSVRecord : batchCSVRecord.getRecords()) { | ||||
|                     add.add(singleCSVRecord.getValues()); | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             ret.add(add); | ||||
|         } | ||||
| 
 | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Convert a writables time series to a sequence batch | ||||
|      * @param input the writables time series to convert | ||||
|      * @return the corresponding sequence batch | ||||
|      */ | ||||
|     public static SequenceBatchCSVRecord fromWritables(List<List<List<Writable>>> input) { | ||||
|         SequenceBatchCSVRecord ret = new SequenceBatchCSVRecord(); | ||||
|         for(int i = 0; i < input.size(); i++) { | ||||
|             ret.add(Arrays.asList(BatchCSVRecord.fromWritables(input.get(i)))); | ||||
|         } | ||||
| 
 | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Return a batch record based on a dataset | ||||
|      * @param dataSet the dataset to get the batch record for | ||||
|      * @return the batch record | ||||
|      */ | ||||
|     public static SequenceBatchCSVRecord fromDataSet(MultiDataSet dataSet) { | ||||
|         SequenceBatchCSVRecord batchCSVRecord = new SequenceBatchCSVRecord(); | ||||
|         for (int i = 0; i < dataSet.numFeatureArrays(); i++) { | ||||
|             batchCSVRecord.add(Arrays.asList(BatchCSVRecord.fromDataSet(new DataSet(dataSet.getFeatures(i),dataSet.getLabels(i))))); | ||||
|         } | ||||
| 
 | ||||
|         return batchCSVRecord; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
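| A short sketch of the conversion path above (DoubleWritable comes from datavec-api, as in the tests later in this commit; the exact string forms depend on BatchCSVRecord.fromWritables, which is not shown here): | ||||
| 
|  | ||||
|     List<List<List<Writable>>> timeSeries = Arrays.asList( | ||||
|             Arrays.asList(Arrays.asList((Writable) new DoubleWritable(1.0), new DoubleWritable(2.0)))); | ||||
|     SequenceBatchCSVRecord seq = SequenceBatchCSVRecord.fromWritables(timeSeries); | ||||
|     List<List<List<String>>> unpacked = seq.getRecordsAsString(); // "unpacks" every value as a string | ||||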
| @ -1,95 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.spark.inference.model.model; | ||||
| 
 | ||||
| import lombok.AllArgsConstructor; | ||||
| import lombok.Data; | ||||
| import lombok.NoArgsConstructor; | ||||
| import org.nd4j.linalg.dataset.DataSet; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| import java.util.Arrays; | ||||
| import java.util.List; | ||||
| 
 | ||||
| @Data | ||||
| @AllArgsConstructor | ||||
| @NoArgsConstructor | ||||
| public class SingleCSVRecord implements Serializable { | ||||
|     private List<String> values; | ||||
| 
 | ||||
|     /** | ||||
|      * Create from an array of values (uses a list internally) | ||||
|      * @param values the values of the record | ||||
|      */ | ||||
|     public SingleCSVRecord(String...values) { | ||||
|         this.values = Arrays.asList(values); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Instantiate a csv record from a dataset row. | ||||
|      * For classification (a one-hot label vector), the index of | ||||
|      * the maximum label value is appended to the end of the record; | ||||
|      * for regression, all label values are appended. | ||||
|      * @param row the input dataset; features and labels must be scalars or vectors | ||||
|      * @return the record for this {@link DataSet} | ||||
|      */ | ||||
|     public static SingleCSVRecord fromRow(DataSet row) { | ||||
|         if (!row.getFeatures().isVector() && !row.getFeatures().isScalar()) | ||||
|             throw new IllegalArgumentException("Passed in dataset must represent a scalar or vector"); | ||||
|         if (!row.getLabels().isVector() && !row.getLabels().isScalar()) | ||||
|             throw new IllegalArgumentException("Passed in dataset labels must be a scalar or vector"); | ||||
|         //classification: a one-hot label vector sums to exactly 1.0 | ||||
|         SingleCSVRecord record; | ||||
|         int idx = 0; | ||||
|         if (row.getLabels().sumNumber().doubleValue() == 1.0) { | ||||
|             String[] values = new String[row.getFeatures().columns() + 1]; | ||||
|             for (int i = 0; i < row.getFeatures().length(); i++) { | ||||
|                 values[idx++] = String.valueOf(row.getFeatures().getDouble(i)); | ||||
|             } | ||||
|             int maxIdx = 0; | ||||
|             for (int i = 0; i < row.getLabels().length(); i++) { | ||||
|                 if (row.getLabels().getDouble(maxIdx) < row.getLabels().getDouble(i)) { | ||||
|                     maxIdx = i; | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             values[idx++] = String.valueOf(maxIdx); | ||||
|             record = new SingleCSVRecord(values); | ||||
|         } | ||||
|         //regression (any number of values) | ||||
|         else { | ||||
|             String[] values = new String[row.getFeatures().columns() + row.getLabels().columns()]; | ||||
|             for (int i = 0; i < row.getFeatures().length(); i++) { | ||||
|                 values[idx++] = String.valueOf(row.getFeatures().getDouble(i)); | ||||
|             } | ||||
|             for (int i = 0; i < row.getLabels().length(); i++) { | ||||
|                 values[idx++] = String.valueOf(row.getLabels().getDouble(i)); | ||||
|             } | ||||
| 
 | ||||
| 
 | ||||
|             record = new SingleCSVRecord(values); | ||||
| 
 | ||||
|         } | ||||
|         return record; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
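| A worked sketch of the two branches above (row vectors built with Nd4j, as in the tests later in this commit; the values are illustrative): | ||||
| 
|  | ||||
|     // one-hot labels sum to exactly 1.0 -> classification: the argmax index is appended | ||||
|     DataSet oneHot = new DataSet(Nd4j.create(new double[] {0.5, 0.5}), Nd4j.create(new double[] {0, 1})); | ||||
|     SingleCSVRecord r1 = SingleCSVRecord.fromRow(oneHot); // values: ["0.5", "0.5", "1"] | ||||
| 
|  | ||||
|     // any other label vector -> regression: every label value is appended | ||||
|     DataSet reg = new DataSet(Nd4j.create(new double[] {0.5, 0.5}), Nd4j.create(new double[] {1, 1})); | ||||
|     SingleCSVRecord r2 = SingleCSVRecord.fromRow(reg); // values: ["0.5", "0.5", "1.0", "1.0"] | ||||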
| @ -1,34 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.spark.inference.model.model; | ||||
| 
 | ||||
| import lombok.AllArgsConstructor; | ||||
| import lombok.Data; | ||||
| import lombok.NoArgsConstructor; | ||||
| 
 | ||||
| import java.net.URI; | ||||
| 
 | ||||
| @Data | ||||
| @AllArgsConstructor | ||||
| @NoArgsConstructor | ||||
| public class SingleImageRecord { | ||||
|     private URI uri; | ||||
| } | ||||
| @ -1,131 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.spark.inference.model.service; | ||||
| 
 | ||||
| import org.datavec.api.transform.TransformProcess; | ||||
| import org.datavec.image.transform.ImageTransformProcess; | ||||
| import org.datavec.spark.inference.model.model.*; | ||||
| 
 | ||||
| import java.io.IOException; | ||||
| 
 | ||||
| public interface DataVecTransformService { | ||||
| 
 | ||||
|     /** | ||||
|      * Name of the HTTP header used to flag that a payload contains sequence data | ||||
|      */ | ||||
|     String SEQUENCE_OR_NOT_HEADER = "Sequence"; | ||||
| 
|  | ||||
|     /** | ||||
|      * Set the transform process to apply to CSV records | ||||
|      * @param transformProcess the transform process to use | ||||
|      */ | ||||
|     void setCSVTransformProcess(TransformProcess transformProcess); | ||||
| 
|  | ||||
|     /** | ||||
|      * Set the transform process to apply to images | ||||
|      * @param imageTransformProcess the image transform process to use | ||||
|      */ | ||||
|     void setImageTransformProcess(ImageTransformProcess imageTransformProcess); | ||||
| 
|  | ||||
|     /** | ||||
|      * @return the current CSV transform process | ||||
|      */ | ||||
|     TransformProcess getCSVTransformProcess(); | ||||
| 
|  | ||||
|     /** | ||||
|      * @return the current image transform process | ||||
|      */ | ||||
|     ImageTransformProcess getImageTransformProcess(); | ||||
| 
|  | ||||
|     /** | ||||
|      * Transform a single CSV record | ||||
|      * @param singleCsvRecord the record to transform | ||||
|      * @return the transformed record | ||||
|      */ | ||||
|     SingleCSVRecord transformIncremental(SingleCSVRecord singleCsvRecord); | ||||
| 
|  | ||||
|     /** | ||||
|      * Transform a batch of sequences | ||||
|      * @param batchCSVRecord the sequence batch to transform | ||||
|      * @return the transformed sequence batch | ||||
|      */ | ||||
|     SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord); | ||||
| 
|  | ||||
|     /** | ||||
|      * Transform a batch of CSV records | ||||
|      * @param batchCSVRecord the batch to transform | ||||
|      * @return the transformed batch | ||||
|      */ | ||||
|     BatchCSVRecord transform(BatchCSVRecord batchCSVRecord); | ||||
| 
|  | ||||
|     /** | ||||
|      * Transform a batch of CSV records into a (base64-encoded) NDArray | ||||
|      * @param batchCSVRecord the batch to transform | ||||
|      * @return the resulting array, base64-encoded | ||||
|      */ | ||||
|     Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord); | ||||
| 
|  | ||||
|     /** | ||||
|      * Transform a single CSV record into a (base64-encoded) NDArray | ||||
|      * @param singleCsvRecord the record to transform | ||||
|      * @return the resulting array, base64-encoded | ||||
|      */ | ||||
|     Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord); | ||||
| 
|  | ||||
|     /** | ||||
|      * Transform a single image into a (base64-encoded) NDArray | ||||
|      * @param singleImageRecord the image to transform | ||||
|      * @return the resulting array, base64-encoded | ||||
|      * @throws IOException if the image cannot be read | ||||
|      */ | ||||
|     Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException; | ||||
| 
|  | ||||
|     /** | ||||
|      * Transform a batch of images into a single (base64-encoded) NDArray | ||||
|      * @param batchImageRecord the images to transform | ||||
|      * @return the resulting array, base64-encoded | ||||
|      * @throws IOException if an image cannot be read | ||||
|      */ | ||||
|     Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException; | ||||
| 
|  | ||||
|     /** | ||||
|      * Transform a batch of records, treated as a single sequence, into a (base64-encoded) NDArray | ||||
|      * @param singleCsvRecord the batch of records to transform | ||||
|      * @return the resulting array, base64-encoded | ||||
|      */ | ||||
|     Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord); | ||||
| 
|  | ||||
|     /** | ||||
|      * Transform a batch of sequences into a (base64-encoded) NDArray | ||||
|      * @param batchCSVRecord the sequence batch to transform | ||||
|      * @return the resulting array, base64-encoded | ||||
|      */ | ||||
|     Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord); | ||||
| 
|  | ||||
|     /** | ||||
|      * Transform a batch of sequences | ||||
|      * @param batchCSVRecord the sequence batch to transform | ||||
|      * @return the transformed sequence batch | ||||
|      */ | ||||
|     SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord); | ||||
| 
|  | ||||
|     /** | ||||
|      * Transform a batch of records, treated as a single sequence | ||||
|      * @param transform the batch of records to convert | ||||
|      * @return the transformed sequence batch | ||||
|      */ | ||||
|     SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform); | ||||
| } | ||||
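| A sketch of the intended call pattern (CSVSparkTransformServer, which implements this interface, appears later in this commit; transformProcess and batchCSVRecord are assumed to be built elsewhere): | ||||
| 
|  | ||||
|     DataVecTransformService service = new CSVSparkTransformServer(); | ||||
|     service.setCSVTransformProcess(transformProcess);                 // install the transform process | ||||
|     BatchCSVRecord transformed = service.transform(batchCSVRecord);   // transformed records | ||||
|     Base64NDArrayBody array = service.transformArray(batchCSVRecord); // same records as a base64-encoded NDArray | ||||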
| @ -1,46 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| package org.datavec.spark.transform; | ||||
| 
 | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.nd4j.common.tests.AbstractAssertTestsClass; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| 
 | ||||
| import java.util.*; | ||||
| 
 | ||||
| @Slf4j | ||||
| public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { | ||||
| 
 | ||||
|     @Override | ||||
|     protected Set<Class<?>> getExclusions() { | ||||
|         //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) | ||||
|         return new HashSet<>(); | ||||
|     } | ||||
| 
|  | ||||
|     @Override | ||||
|     protected String getPackageName() { | ||||
|         return "org.datavec.spark.transform"; | ||||
|     } | ||||
| 
|  | ||||
|     @Override | ||||
|     protected Class<?> getBaseClass() { | ||||
|         return BaseND4JTest.class; | ||||
|     } | ||||
| } | ||||
| @ -1,40 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.spark.transform; | ||||
| 
 | ||||
| import org.datavec.spark.inference.model.model.BatchCSVRecord; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.dataset.DataSet; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import static org.junit.Assert.assertEquals; | ||||
| 
 | ||||
| public class BatchCSVRecordTest { | ||||
| 
 | ||||
|     @Test | ||||
|     public void testBatchRecordCreationFromDataSet() { | ||||
|         DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{1, 1}, {1, 1}})); | ||||
| 
 | ||||
|         BatchCSVRecord batchCSVRecord = BatchCSVRecord.fromDataSet(dataSet); | ||||
|         assertEquals(2, batchCSVRecord.getRecords().size()); | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,212 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.spark.transform; | ||||
| 
 | ||||
| import org.datavec.api.transform.TransformProcess; | ||||
| import org.datavec.api.transform.schema.Schema; | ||||
| import org.datavec.api.transform.transform.integer.BaseIntegerTransform; | ||||
| import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform; | ||||
| import org.datavec.api.writable.DoubleWritable; | ||||
| import org.datavec.api.writable.Text; | ||||
| import org.datavec.api.writable.Writable; | ||||
| import org.datavec.spark.inference.model.CSVSparkTransform; | ||||
| import org.datavec.spark.inference.model.model.Base64NDArrayBody; | ||||
| import org.datavec.spark.inference.model.model.BatchCSVRecord; | ||||
| import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord; | ||||
| import org.datavec.spark.inference.model.model.SingleCSVRecord; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.serde.base64.Nd4jBase64; | ||||
| 
 | ||||
| import java.util.*; | ||||
| 
 | ||||
| import static org.junit.Assert.*; | ||||
| 
 | ||||
| public class CSVSparkTransformTest { | ||||
|     @Test | ||||
|     public void testTransformer() throws Exception { | ||||
|         List<Writable> input = new ArrayList<>(); | ||||
|         input.add(new DoubleWritable(1.0)); | ||||
|         input.add(new DoubleWritable(2.0)); | ||||
| 
 | ||||
|         Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); | ||||
|         List<Writable> output = new ArrayList<>(); | ||||
|         output.add(new Text("1.0")); | ||||
|         output.add(new Text("2.0")); | ||||
| 
 | ||||
|         TransformProcess transformProcess = | ||||
|                 new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build(); | ||||
|         CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess); | ||||
|         String[] values = new String[] {"1.0", "2.0"}; | ||||
|         SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values)); | ||||
|         Base64NDArrayBody body = csvSparkTransform.toArray(new SingleCSVRecord(values)); | ||||
|         INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); | ||||
|         assertTrue(fromBase64.isVector()); | ||||
| //        System.out.println("Base 64ed array " + fromBase64); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testTransformerBatch() throws Exception { | ||||
|         List<Writable> input = new ArrayList<>(); | ||||
|         input.add(new DoubleWritable(1.0)); | ||||
|         input.add(new DoubleWritable(2.0)); | ||||
| 
 | ||||
|         Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); | ||||
|         List<Writable> output = new ArrayList<>(); | ||||
|         output.add(new Text("1.0")); | ||||
|         output.add(new Text("2.0")); | ||||
| 
 | ||||
|         TransformProcess transformProcess = | ||||
|                 new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build(); | ||||
|         CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess); | ||||
|         String[] values = new String[] {"1.0", "2.0"}; | ||||
|         SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values)); | ||||
|         BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); | ||||
|         for (int i = 0; i < 3; i++) | ||||
|             batchCSVRecord.add(record); | ||||
|         //data type is string, unable to convert | ||||
|         BatchCSVRecord batchCSVRecord1 = csvSparkTransform.transform(batchCSVRecord); | ||||
|       /*  Base64NDArrayBody body = csvSparkTransform.toArray(batchCSVRecord1); | ||||
|         INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); | ||||
|         assertTrue(fromBase64.isMatrix()); | ||||
|         System.out.println("Base 64ed array " + fromBase64); */ | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     @Test | ||||
|     public void testSingleBatchSequence() throws Exception { | ||||
|         List<Writable> input = new ArrayList<>(); | ||||
|         input.add(new DoubleWritable(1.0)); | ||||
|         input.add(new DoubleWritable(2.0)); | ||||
| 
 | ||||
|         Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); | ||||
|         List<Writable> output = new ArrayList<>(); | ||||
|         output.add(new Text("1.0")); | ||||
|         output.add(new Text("2.0")); | ||||
| 
 | ||||
|         TransformProcess transformProcess = | ||||
|                 new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build(); | ||||
|         CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess); | ||||
|         String[] values = new String[] {"1.0", "2.0"}; | ||||
|         SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values)); | ||||
|         BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); | ||||
|         for (int i = 0; i < 3; i++) | ||||
|             batchCSVRecord.add(record); | ||||
|         BatchCSVRecord batchCSVRecord1 = csvSparkTransform.transform(batchCSVRecord); | ||||
|         SequenceBatchCSVRecord sequenceBatchCSVRecord = new SequenceBatchCSVRecord(); | ||||
|         sequenceBatchCSVRecord.add(Arrays.asList(batchCSVRecord)); | ||||
|         Base64NDArrayBody sequenceArray = csvSparkTransform.transformSequenceArray(sequenceBatchCSVRecord); | ||||
|         INDArray outputBody = Nd4jBase64.fromBase64(sequenceArray.getNdarray()); | ||||
| 
 | ||||
| 
 | ||||
|          //ensure accumulation | ||||
|         sequenceBatchCSVRecord.add(Arrays.asList(batchCSVRecord)); | ||||
|         sequenceArray = csvSparkTransform.transformSequenceArray(sequenceBatchCSVRecord); | ||||
|         assertArrayEquals(new long[]{2,2,3},Nd4jBase64.fromBase64(sequenceArray.getNdarray()).shape()); | ||||
| 
 | ||||
|         SequenceBatchCSVRecord transformed = csvSparkTransform.transformSequence(sequenceBatchCSVRecord); | ||||
|         assertNotNull(transformed.getRecords()); | ||||
| //        System.out.println(transformed); | ||||
| 
 | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testSpecificSequence() throws Exception { | ||||
|         final Schema schema = new Schema.Builder() | ||||
|                 .addColumnsString("action") | ||||
|                 .build(); | ||||
| 
 | ||||
|         final TransformProcess transformProcess = new TransformProcess.Builder(schema) | ||||
|                 .removeAllColumnsExceptFor("action") | ||||
|                 .transform(new ConvertToLowercase("action")) | ||||
|                 .convertToSequence() | ||||
|                 .transform(new TextToCharacterIndexTransform("action", "action_sequence", | ||||
|                         defaultCharIndex(), false)) | ||||
|                 .integerToOneHot("action_sequence",0,29) | ||||
|                 .build(); | ||||
| 
 | ||||
|         final String[] data1 = new String[] { "test1" }; | ||||
|         final String[] data2 = new String[] { "test2" }; | ||||
|         final BatchCSVRecord batchCsvRecord = new BatchCSVRecord( | ||||
|                 Arrays.asList( | ||||
|                         new SingleCSVRecord(data1), | ||||
|                         new SingleCSVRecord(data2))); | ||||
| 
 | ||||
|         final CSVSparkTransform transform = new CSVSparkTransform(transformProcess); | ||||
| //        System.out.println(transform.transformSequenceIncremental(batchCsvRecord)); | ||||
|         transform.transformSequenceIncremental(batchCsvRecord); | ||||
|         assertEquals(3,Nd4jBase64.fromBase64(transform.transformSequenceArrayIncremental(batchCsvRecord).getNdarray()).rank()); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     private static Map<Character,Integer> defaultCharIndex() { | ||||
|         Map<Character,Integer> ret = new TreeMap<>(); | ||||
| 
 | ||||
|         ret.put('a',0); | ||||
|         ret.put('b',1); | ||||
|         ret.put('c',2); | ||||
|         ret.put('d',3); | ||||
|         ret.put('e',4); | ||||
|         ret.put('f',5); | ||||
|         ret.put('g',6); | ||||
|         ret.put('h',7); | ||||
|         ret.put('i',8); | ||||
|         ret.put('j',9); | ||||
|         ret.put('k',10); | ||||
|         ret.put('l',11); | ||||
|         ret.put('m',12); | ||||
|         ret.put('n',13); | ||||
|         ret.put('o',14); | ||||
|         ret.put('p',15); | ||||
|         ret.put('q',16); | ||||
|         ret.put('r',17); | ||||
|         ret.put('s',18); | ||||
|         ret.put('t',19); | ||||
|         ret.put('u',20); | ||||
|         ret.put('v',21); | ||||
|         ret.put('w',22); | ||||
|         ret.put('x',23); | ||||
|         ret.put('y',24); | ||||
|         ret.put('z',25); | ||||
|         ret.put('/',26); | ||||
|         ret.put(' ',27); | ||||
|         ret.put('(',28); | ||||
|         ret.put(')',29); | ||||
| 
 | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
|     public static class ConvertToLowercase extends BaseIntegerTransform { | ||||
|         public ConvertToLowercase(String column) { | ||||
|             super(column); | ||||
|         } | ||||
| 
 | ||||
|         public Text map(Writable writable) { | ||||
|             return new Text(writable.toString().toLowerCase()); | ||||
|         } | ||||
| 
 | ||||
|         public Object map(Object input) { | ||||
|             return new Text(input.toString().toLowerCase()); | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @ -1,86 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.spark.transform; | ||||
| 
 | ||||
| import org.datavec.image.transform.ImageTransformProcess; | ||||
| import org.datavec.spark.inference.model.ImageSparkTransform; | ||||
| import org.datavec.spark.inference.model.model.Base64NDArrayBody; | ||||
| import org.datavec.spark.inference.model.model.BatchImageRecord; | ||||
| import org.datavec.spark.inference.model.model.SingleImageRecord; | ||||
| import org.junit.Rule; | ||||
| import org.junit.Test; | ||||
| import org.junit.rules.TemporaryFolder; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.serde.base64.Nd4jBase64; | ||||
| 
 | ||||
| import java.io.File; | ||||
| 
 | ||||
| import static org.junit.Assert.assertEquals; | ||||
| 
 | ||||
| public class ImageSparkTransformTest { | ||||
| 
 | ||||
|     @Rule | ||||
|     public TemporaryFolder testDir = new TemporaryFolder(); | ||||
| 
 | ||||
|     @Test | ||||
|     public void testSingleImageSparkTransform() throws Exception { | ||||
|         int seed = 12345; | ||||
| 
 | ||||
|         File f1 = new ClassPathResource("datavec-spark-inference/testimages/class1/A.jpg").getFile(); | ||||
| 
 | ||||
|         SingleImageRecord imgRecord = new SingleImageRecord(f1.toURI()); | ||||
| 
 | ||||
|         ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(seed) | ||||
|                         .scaleImageTransform(10).cropImageTransform(5).build(); | ||||
| 
 | ||||
|         ImageSparkTransform imgSparkTransform = new ImageSparkTransform(imgTransformProcess); | ||||
|         Base64NDArrayBody body = imgSparkTransform.toArray(imgRecord); | ||||
| 
 | ||||
|         INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); | ||||
| //        System.out.println("Base 64ed array " + fromBase64); | ||||
|         assertEquals(1, fromBase64.size(0)); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testBatchImageSparkTransform() throws Exception { | ||||
|         int seed = 12345; | ||||
| 
 | ||||
|         File f0 = new ClassPathResource("datavec-spark-inference/testimages/class1/A.jpg").getFile(); | ||||
|         File f1 = new ClassPathResource("datavec-spark-inference/testimages/class1/B.png").getFile(); | ||||
|         File f2 = new ClassPathResource("datavec-spark-inference/testimages/class1/C.jpg").getFile(); | ||||
| 
 | ||||
|         BatchImageRecord batch = new BatchImageRecord(); | ||||
|         batch.add(f0.toURI()); | ||||
|         batch.add(f1.toURI()); | ||||
|         batch.add(f2.toURI()); | ||||
| 
 | ||||
|         ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(seed) | ||||
|                         .scaleImageTransform(10).cropImageTransform(5).build(); | ||||
| 
 | ||||
|         ImageSparkTransform imgSparkTransform = new ImageSparkTransform(imgTransformProcess); | ||||
|         Base64NDArrayBody body = imgSparkTransform.toArray(batch); | ||||
| 
 | ||||
|         INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); | ||||
| //        System.out.println("Base 64ed array " + fromBase64); | ||||
|         assertEquals(3, fromBase64.size(0)); | ||||
|     } | ||||
| } | ||||
| @ -1,60 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.spark.transform; | ||||
| 
 | ||||
| import org.datavec.spark.inference.model.model.SingleCSVRecord; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.dataset.DataSet; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import static org.junit.Assert.assertEquals; | ||||
| import static org.junit.Assert.fail; | ||||
| 
 | ||||
| public class SingleCSVRecordTest { | ||||
| 
 | ||||
|     @Test(expected = IllegalArgumentException.class) | ||||
|     public void testVectorAssertion() { | ||||
|         DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(1, 1)); | ||||
|         SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet); | ||||
|         fail(singleCsvRecord.toString() + " should have thrown an exception"); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testVectorOneHotLabel() { | ||||
|         DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{0, 1}, {1, 0}})); | ||||
| 
 | ||||
|         //one-hot labels -> classification: 2 feature values + 1 label index | ||||
|         SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet.get(0)); | ||||
|         assertEquals(3, singleCsvRecord.getValues().size()); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testVectorRegression() { | ||||
|         DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{1, 1}, {1, 1}})); | ||||
| 
 | ||||
|         //non-one-hot labels -> regression: 2 feature values + 2 label values | ||||
|         SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet.get(0)); | ||||
|         assertEquals(4, singleCsvRecord.getValues().size()); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,47 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.spark.transform; | ||||
| 
 | ||||
| import org.datavec.spark.inference.model.model.SingleImageRecord; | ||||
| import org.junit.Rule; | ||||
| import org.junit.Test; | ||||
| import org.junit.rules.TemporaryFolder; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| 
 | ||||
| import java.io.File; | ||||
| 
|  | ||||
| import static org.junit.Assert.assertEquals; | ||||
| 
 | ||||
| public class SingleImageRecordTest { | ||||
| 
 | ||||
|     @Rule | ||||
|     public TemporaryFolder testDir = new TemporaryFolder(); | ||||
| 
 | ||||
|     @Test | ||||
|     public void testImageRecord() throws Exception { | ||||
|         File f = testDir.newFolder(); | ||||
|         new ClassPathResource("datavec-spark-inference/testimages/").copyDirectory(f); | ||||
|         File f0 = new File(f, "class0/0.jpg"); | ||||
|         File f1 = new File(f, "class1/A.jpg"); | ||||
| 
|  | ||||
|         SingleImageRecord imgRecord = new SingleImageRecord(f0.toURI()); | ||||
|         assertEquals(f0.toURI(), imgRecord.getUri()); | ||||
| 
|  | ||||
|         imgRecord = new SingleImageRecord(f1.toURI()); | ||||
|         assertEquals(f1.toURI(), imgRecord.getUri()); | ||||
|     } | ||||
| } | ||||
| @ -1,154 +0,0 @@ | ||||
| <?xml version="1.0" encoding="UTF-8"?> | ||||
| <!-- | ||||
|   ~ /* ****************************************************************************** | ||||
|   ~  * | ||||
|   ~  * | ||||
|   ~  * This program and the accompanying materials are made available under the | ||||
|   ~  * terms of the Apache License, Version 2.0 which is available at | ||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|   ~  * | ||||
|   ~  *  See the NOTICE file distributed with this work for additional | ||||
|   ~  *  information regarding copyright ownership. | ||||
|   ~  * Unless required by applicable law or agreed to in writing, software | ||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|   ~  * License for the specific language governing permissions and limitations | ||||
|   ~  * under the License. | ||||
|   ~  * | ||||
|   ~  * SPDX-License-Identifier: Apache-2.0 | ||||
|   ~  ******************************************************************************/ | ||||
|   --> | ||||
| 
 | ||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" | ||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||||
| 
 | ||||
|     <modelVersion>4.0.0</modelVersion> | ||||
| 
 | ||||
|     <parent> | ||||
|         <groupId>org.datavec</groupId> | ||||
|         <artifactId>datavec-spark-inference-parent</artifactId> | ||||
|         <version>1.0.0-SNAPSHOT</version> | ||||
|     </parent> | ||||
| 
 | ||||
|     <artifactId>datavec-spark-inference-server_2.11</artifactId> | ||||
| 
 | ||||
|     <name>datavec-spark-inference-server</name> | ||||
| 
 | ||||
|     <properties> | ||||
|         <!-- Default scala versions, may be overwritten by build profiles --> | ||||
|         <scala.version>2.11.12</scala.version> | ||||
|         <scala.binary.version>2.11</scala.binary.version> | ||||
|         <maven.compiler.source>1.8</maven.compiler.source> | ||||
|         <maven.compiler.target>1.8</maven.compiler.target> | ||||
|     </properties> | ||||
| 
 | ||||
|     <dependencies> | ||||
|         <dependency> | ||||
|             <groupId>org.datavec</groupId> | ||||
|             <artifactId>datavec-spark-inference-model</artifactId> | ||||
|             <version>${datavec.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.datavec</groupId> | ||||
|             <artifactId>datavec-spark_2.11</artifactId> | ||||
|             <version>${project.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.datavec</groupId> | ||||
|             <artifactId>datavec-data-image</artifactId> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>joda-time</groupId> | ||||
|             <artifactId>joda-time</artifactId> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.apache.commons</groupId> | ||||
|             <artifactId>commons-lang3</artifactId> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.hibernate</groupId> | ||||
|             <artifactId>hibernate-validator</artifactId> | ||||
|             <version>${hibernate.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.scala-lang</groupId> | ||||
|             <artifactId>scala-library</artifactId> | ||||
|             <version>${scala.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.scala-lang</groupId> | ||||
|             <artifactId>scala-reflect</artifactId> | ||||
|             <version>${scala.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>com.typesafe.play</groupId> | ||||
|             <artifactId>play-java_2.11</artifactId> | ||||
|             <version>${playframework.version}</version> | ||||
|             <exclusions> | ||||
|                 <exclusion> | ||||
|                     <groupId>com.google.code.findbugs</groupId> | ||||
|                     <artifactId>jsr305</artifactId> | ||||
|                 </exclusion> | ||||
|                 <exclusion> | ||||
|                     <groupId>net.jodah</groupId> | ||||
|                     <artifactId>typetools</artifactId> | ||||
|                 </exclusion> | ||||
|             </exclusions> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>net.jodah</groupId> | ||||
|             <artifactId>typetools</artifactId> | ||||
|             <version>${jodah.typetools.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>com.typesafe.play</groupId> | ||||
|             <artifactId>play-json_2.11</artifactId> | ||||
|             <version>${playframework.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>com.typesafe.play</groupId> | ||||
|             <artifactId>play-server_2.11</artifactId> | ||||
|             <version>${playframework.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>com.typesafe.play</groupId> | ||||
|             <artifactId>play_2.11</artifactId> | ||||
|             <version>${playframework.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>com.typesafe.play</groupId> | ||||
|             <artifactId>play-netty-server_2.11</artifactId> | ||||
|             <version>${playframework.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>com.typesafe.akka</groupId> | ||||
|             <artifactId>akka-cluster_2.11</artifactId> | ||||
|             <version>2.5.23</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>com.mashape.unirest</groupId> | ||||
|             <artifactId>unirest-java</artifactId> | ||||
|             <scope>test</scope> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>com.beust</groupId> | ||||
|             <artifactId>jcommander</artifactId> | ||||
|             <version>${jcommander.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.apache.spark</groupId> | ||||
|             <artifactId>spark-core_2.11</artifactId> | ||||
|             <version>${spark.version}</version> | ||||
|         </dependency> | ||||
|     </dependencies> | ||||
| 
 | ||||
|     <profiles> | ||||
|         <profile> | ||||
|             <id>test-nd4j-native</id> | ||||
|         </profile> | ||||
|         <profile> | ||||
|             <id>test-nd4j-cuda-11.0</id> | ||||
|         </profile> | ||||
|     </profiles> | ||||
| </project> | ||||
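| The two profiles above carry no configuration of their own; they appear to act as markers so the test run can be pinned to a backend, e.g. an invocation along the lines of "mvn -Ptest-nd4j-native test" (assumed, not taken from this file). | ||||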
| @ -1,352 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.spark.inference.server; | ||||
| 
 | ||||
| import com.beust.jcommander.JCommander; | ||||
| import com.beust.jcommander.ParameterException; | ||||
| import lombok.Data; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.apache.commons.io.FileUtils; | ||||
| import org.datavec.api.transform.TransformProcess; | ||||
| import org.datavec.image.transform.ImageTransformProcess; | ||||
| import org.datavec.spark.inference.model.CSVSparkTransform; | ||||
| import org.datavec.spark.inference.model.model.*; | ||||
| import play.BuiltInComponents; | ||||
| import play.Mode; | ||||
| import play.routing.Router; | ||||
| import play.routing.RoutingDsl; | ||||
| import play.server.Server; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.nio.charset.StandardCharsets; | ||||
| import java.util.Base64; | ||||
| import java.util.Random; | ||||
| 
 | ||||
| import static play.mvc.Results.*; | ||||
| 
 | ||||
| @Slf4j | ||||
| @Data | ||||
| public class CSVSparkTransformServer extends SparkTransformServer { | ||||
|     private CSVSparkTransform transform; | ||||
| 
 | ||||
|     public void runMain(String[] args) throws Exception { | ||||
|         JCommander jcmdr = new JCommander(this); | ||||
| 
 | ||||
|         try { | ||||
|             jcmdr.parse(args); | ||||
|         } catch (ParameterException e) { | ||||
|             //User provides invalid input -> print the usage info | ||||
|             jcmdr.usage(); | ||||
|             if (jsonPath == null) | ||||
|                 System.err.println("Json path parameter is missing."); | ||||
|             try { | ||||
|                 Thread.sleep(500); | ||||
|             } catch (Exception e2) { | ||||
|                 //ignore - the brief pause just lets the usage text flush before exiting | ||||
|             } | ||||
|             System.exit(1); | ||||
|         } | ||||
| 
 | ||||
|         if (jsonPath != null) { | ||||
|             String json = FileUtils.readFileToString(new File(jsonPath), StandardCharsets.UTF_8); | ||||
|             TransformProcess transformProcess = TransformProcess.fromJson(json); | ||||
|             transform = new CSVSparkTransform(transformProcess); | ||||
|         } else { | ||||
|             log.warn("Server started with no json for transform process. Please ensure you specify a transform process by sending a POST request with raw JSON " | ||||
|                     + "to /transformprocess"); | ||||
|         } | ||||
| 
 | ||||
|         //Set play secret key, if required | ||||
|         //http://www.playframework.com/documentation/latest/ApplicationSecret | ||||
|         String crypto = System.getProperty("play.crypto.secret"); | ||||
|         if (crypto == null || "changeme".equals(crypto) || "".equals(crypto) ) { | ||||
|             byte[] newCrypto = new byte[1024]; | ||||
| 
 | ||||
|             new Random().nextBytes(newCrypto); | ||||
| 
 | ||||
|             String base64 = Base64.getEncoder().encodeToString(newCrypto); | ||||
|             System.setProperty("play.crypto.secret", base64); | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
|         server = Server.forRouter(Mode.PROD, port, this::createRouter); | ||||
|     } | ||||
| 
 | ||||
|     protected Router createRouter(BuiltInComponents b){ | ||||
|         RoutingDsl routingDsl = RoutingDsl.fromComponents(b); | ||||
| 
 | ||||
|         routingDsl.GET("/transformprocess").routingTo(req -> { | ||||
|             try { | ||||
|                 if (transform == null) | ||||
|                     return badRequest(); | ||||
|                 return ok(transform.getTransformProcess().toJson()).as(contentType); | ||||
|             } catch (Exception e) { | ||||
|                 log.error("Error in GET /transformprocess",e); | ||||
|                 return internalServerError(e.getMessage()); | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         routingDsl.POST("/transformprocess").routingTo(req -> { | ||||
|             try { | ||||
|                 TransformProcess transformProcess = TransformProcess.fromJson(getJsonText(req)); | ||||
|                 setCSVTransformProcess(transformProcess); | ||||
|                 log.info("Transform process initialized"); | ||||
|                 return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType); | ||||
|             } catch (Exception e) { | ||||
|                 log.error("Error in POST /transformprocess",e); | ||||
|                 return internalServerError(e.getMessage()); | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         routingDsl.POST("/transformincremental").routingTo(req -> { | ||||
|             if (isSequence(req)) { | ||||
|                 try { | ||||
|                     BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); | ||||
|                     if (record == null) | ||||
|                         return badRequest(); | ||||
|                     return ok(objectMapper.writeValueAsString(transformSequenceIncremental(record))).as(contentType); | ||||
|                 } catch (Exception e) { | ||||
|                     log.error("Error in /transformincremental", e); | ||||
|                     return internalServerError(e.getMessage()); | ||||
|                 } | ||||
|             } else { | ||||
|                 try { | ||||
|                     SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class); | ||||
|                     if (record == null) | ||||
|                         return badRequest(); | ||||
|                     return ok(objectMapper.writeValueAsString(transformIncremental(record))).as(contentType); | ||||
|                 } catch (Exception e) { | ||||
|                     log.error("Error in /transformincremental", e); | ||||
|                     return internalServerError(e.getMessage()); | ||||
|                 } | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         routingDsl.POST("/transform").routingTo(req -> { | ||||
|             if (isSequence(req)) { | ||||
|                 try { | ||||
|                     SequenceBatchCSVRecord batch = transformSequence(objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class)); | ||||
|                     if (batch == null) | ||||
|                         return badRequest(); | ||||
|                     return ok(objectMapper.writeValueAsString(batch)).as(contentType); | ||||
|                 } catch (Exception e) { | ||||
|                     log.error("Error in /transform", e); | ||||
|                     return internalServerError(e.getMessage()); | ||||
|                 } | ||||
|             } else { | ||||
|                 try { | ||||
|                     BatchCSVRecord input = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); | ||||
|                     BatchCSVRecord batch = transform(input); | ||||
|                     if (batch == null) | ||||
|                         return badRequest(); | ||||
|                     return ok(objectMapper.writeValueAsString(batch)).as(contentType); | ||||
|                 } catch (Exception e) { | ||||
|                     log.error("Error in /transform", e); | ||||
|                     return internalServerError(e.getMessage()); | ||||
|                 } | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         routingDsl.POST("/transformincrementalarray").routingTo(req -> { | ||||
|             if (isSequence(req)) { | ||||
|                 try { | ||||
|                     BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); | ||||
|                     if (record == null) | ||||
|                         return badRequest(); | ||||
|                     return ok(objectMapper.writeValueAsString(transformSequenceArrayIncremental(record))).as(contentType); | ||||
|                 } catch (Exception e) { | ||||
|                     log.error("Error in /transformincrementalarray", e); | ||||
|                     return internalServerError(e.getMessage()); | ||||
|                 } | ||||
|             } else { | ||||
|                 try { | ||||
|                     SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class); | ||||
|                     if (record == null) | ||||
|                         return badRequest(); | ||||
|                     return ok(objectMapper.writeValueAsString(transformArrayIncremental(record))).as(contentType); | ||||
|                 } catch (Exception e) { | ||||
|                     log.error("Error in /transformincrementalarray", e); | ||||
|                     return internalServerError(e.getMessage()); | ||||
|                 } | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         routingDsl.POST("/transformarray").routingTo(req -> { | ||||
|             if (isSequence(req)) { | ||||
|                 try { | ||||
|                     SequenceBatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class); | ||||
|                     if (batchCSVRecord == null) | ||||
|                         return badRequest(); | ||||
|                     return ok(objectMapper.writeValueAsString(transformSequenceArray(batchCSVRecord))).as(contentType); | ||||
|                 } catch (Exception e) { | ||||
|                     log.error("Error in /transformarray", e); | ||||
|                     return internalServerError(e.getMessage()); | ||||
|                 } | ||||
|             } else { | ||||
|                 try { | ||||
|                     BatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); | ||||
|                     if (batchCSVRecord == null) | ||||
|                         return badRequest(); | ||||
|                     return ok(objectMapper.writeValueAsString(transformArray(batchCSVRecord))).as(contentType); | ||||
|                 } catch (Exception e) { | ||||
|                     log.error("Error in /transformarray", e); | ||||
|                     return internalServerError(e.getMessage()); | ||||
|                 } | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         return routingDsl.build(); | ||||
|     } | ||||
| 
 | ||||
|     public static void main(String[] args) throws Exception { | ||||
|         new CSVSparkTransformServer().runMain(args); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Set (or replace) the CSV transform process used by this server | ||||
|      * @param transformProcess the transform process to use | ||||
|      */ | ||||
|     @Override | ||||
|     public void setCSVTransformProcess(TransformProcess transformProcess) { | ||||
|         this.transform = new CSVSparkTransform(transformProcess); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) { | ||||
|         log.error("Unsupported operation: setImageTransformProcess not supported for class", getClass()); | ||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @return the current CSV transform process | ||||
|      */ | ||||
|     @Override | ||||
|     public TransformProcess getCSVTransformProcess() { | ||||
|         return transform.getTransformProcess(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public ImageTransformProcess getImageTransformProcess() { | ||||
|         log.error("Unsupported operation: getImageTransformProcess not supported for class", getClass()); | ||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Transform a batch of records, treated as a single sequence | ||||
|      * @param batchCSVRecord the batch of records to convert | ||||
|      * @return the transformed sequence batch | ||||
|      */ | ||||
|     @Override | ||||
|     public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord batchCSVRecord) { | ||||
|         return this.transform.transformSequenceIncremental(batchCSVRecord); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Transform a batch of sequences | ||||
|      * @param batchCSVRecord the sequence batch to transform | ||||
|      * @return the transformed sequence batch | ||||
|      */ | ||||
|     @Override | ||||
|     public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) { | ||||
|         return transform.transformSequence(batchCSVRecord); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Transform a batch of sequences into a (base64-encoded) NDArray | ||||
|      * @param batchCSVRecord the sequence batch to transform | ||||
|      * @return the resulting array, base64-encoded | ||||
|      */ | ||||
|     @Override | ||||
|     public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) { | ||||
|         return this.transform.transformSequenceArray(batchCSVRecord); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param singleCsvRecord the batch record for one incremental sequence step | ||||
|      * @return the result as a base64-encoded NDArray body | ||||
|      */ | ||||
|     @Override | ||||
|     public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) { | ||||
|         return this.transform.transformSequenceArrayIncremental(singleCsvRecord); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param transform the single record to transform | ||||
|      * @return the transformed record | ||||
|      */ | ||||
|     @Override | ||||
|     public SingleCSVRecord transformIncremental(SingleCSVRecord transform) { | ||||
|         return this.transform.transform(transform); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) { | ||||
|         return this.transform.transform(batchCSVRecord); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param batchCSVRecord the batch to transform | ||||
|      * @return the transformed batch | ||||
|      */ | ||||
|     @Override | ||||
|     public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) { | ||||
|         return transform.transform(batchCSVRecord); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param batchCSVRecord the batch to convert | ||||
|      * @return the result as a base64-encoded NDArray body | ||||
|      */ | ||||
|     @Override | ||||
|     public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) { | ||||
|         try { | ||||
|             return this.transform.toArray(batchCSVRecord); | ||||
|         } catch (IOException e) { | ||||
|             log.error("Error in transformArray",e); | ||||
|             throw new IllegalStateException("Transform array shouldn't throw exception"); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param singleCsvRecord the single record to convert | ||||
|      * @return the result as a base64-encoded NDArray body | ||||
|      */ | ||||
|     @Override | ||||
|     public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) { | ||||
|         try { | ||||
|             return this.transform.toArray(singleCsvRecord); | ||||
|         } catch (IOException e) { | ||||
|             log.error("Error in transformArrayIncremental",e); | ||||
|             throw new IllegalStateException("Transform array shouldn't throw exception"); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException { | ||||
|         log.error("Unsupported operation: transformIncrementalArray(SingleImageRecord) not supported for class", getClass()); | ||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException { | ||||
|         log.error("Unsupported operation: transformArray(BatchImageRecord) not supported for class", getClass()); | ||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); | ||||
|     } | ||||
| } | ||||
| @ -1,261 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.spark.inference.server; | ||||
| 
 | ||||
| import com.beust.jcommander.JCommander; | ||||
| import com.beust.jcommander.ParameterException; | ||||
| import lombok.Data; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.apache.commons.io.FileUtils; | ||||
| import org.datavec.api.transform.TransformProcess; | ||||
| import org.datavec.image.transform.ImageTransformProcess; | ||||
| import org.datavec.spark.inference.model.ImageSparkTransform; | ||||
| import org.datavec.spark.inference.model.model.*; | ||||
| import play.BuiltInComponents; | ||||
| import play.Mode; | ||||
| import play.libs.Files; | ||||
| import play.mvc.Http; | ||||
| import play.routing.Router; | ||||
| import play.routing.RoutingDsl; | ||||
| import play.server.Server; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.nio.charset.StandardCharsets; | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| import static play.mvc.Results.*; | ||||
| 
 | ||||
| @Slf4j | ||||
| @Data | ||||
| public class ImageSparkTransformServer extends SparkTransformServer { | ||||
|     private ImageSparkTransform transform; | ||||
| 
 | ||||
|     public void runMain(String[] args) throws Exception { | ||||
|         JCommander jcmdr = new JCommander(this); | ||||
| 
 | ||||
|         try { | ||||
|             jcmdr.parse(args); | ||||
|         } catch (ParameterException e) { | ||||
|             //User provides invalid input -> print the usage info | ||||
|             jcmdr.usage(); | ||||
|             if (jsonPath == null) | ||||
|                 System.err.println("Json path parameter is missing."); | ||||
|             try { | ||||
|                 Thread.sleep(500); | ||||
|             } catch (Exception e2) { | ||||
|                 // ignore: the brief pause only gives the usage message time to flush before exiting | ||||
|             } | ||||
|             System.exit(1); | ||||
|         } | ||||
| 
 | ||||
|         if (jsonPath != null) { | ||||
|             String json = FileUtils.readFileToString(new File(jsonPath), StandardCharsets.UTF_8); | ||||
|             ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(json); | ||||
|             transform = new ImageSparkTransform(transformProcess); | ||||
|         } else { | ||||
|             log.warn("Server started with no json for transform process. Please ensure you specify a transform process via sending a post request with raw json" | ||||
|                     + "to /transformprocess"); | ||||
|         } | ||||
| 
 | ||||
|         server = Server.forRouter(Mode.PROD, port, this::createRouter); | ||||
|     } | ||||
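|     // A minimal launch sketch (hypothetical jar/classpath; the flags are declared in SparkTransformServer): | ||||
|     //   java -cp <datavec-spark-inference-server jar> org.datavec.spark.inference.server.ImageSparkTransformServer \ | ||||
|     //       --jsonPath /path/to/imageTransformProcess.json -dp 9060 | ||||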
| 
 | ||||
|     protected Router createRouter(BuiltInComponents builtInComponents){ | ||||
|         RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents); | ||||
| 
 | ||||
|         routingDsl.GET("/transformprocess").routingTo(req -> { | ||||
|             try { | ||||
|                 if (transform == null) | ||||
|                     return badRequest(); | ||||
|                 log.info("Transform process initialized"); | ||||
|                 return ok(objectMapper.writeValueAsString(transform.getImageTransformProcess())).as(contentType); | ||||
|             } catch (Exception e) { | ||||
|                 log.error("",e); | ||||
|                 return internalServerError(); | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         routingDsl.POST("/transformprocess").routingTo(req -> { | ||||
|             try { | ||||
|                 ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(getJsonText(req)); | ||||
|                 setImageTransformProcess(transformProcess); | ||||
|                 log.info("Transform process initialized"); | ||||
|                 return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType); | ||||
|             } catch (Exception e) { | ||||
|                 log.error("",e); | ||||
|                 return internalServerError(); | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         routingDsl.POST("/transformincrementalarray").routingTo(req -> { | ||||
|             try { | ||||
|                 SingleImageRecord record = objectMapper.readValue(getJsonText(req), SingleImageRecord.class); | ||||
|                 if (record == null) | ||||
|                     return badRequest(); | ||||
|                 return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType); | ||||
|             } catch (Exception e) { | ||||
|                 log.error("",e); | ||||
|                 return internalServerError(); | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         routingDsl.POST("/transformincrementalimage").routingTo(req -> { | ||||
|             try { | ||||
|                 Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData(); | ||||
|                 List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles(); | ||||
|                 if (files.isEmpty() || files.get(0).getRef() == null) { | ||||
|                     return badRequest(); | ||||
|                 } | ||||
| 
 | ||||
|                 File file = files.get(0).getRef().path().toFile(); | ||||
|                 SingleImageRecord record = new SingleImageRecord(file.toURI()); | ||||
| 
 | ||||
|                 return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType); | ||||
|             } catch (Exception e) { | ||||
|                 log.error("",e); | ||||
|                 return internalServerError(); | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         routingDsl.POST("/transformarray").routingTo(req -> { | ||||
|             try { | ||||
|                 BatchImageRecord batch = objectMapper.readValue(getJsonText(req), BatchImageRecord.class); | ||||
|                 if (batch == null) | ||||
|                     return badRequest(); | ||||
|                 return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType); | ||||
|             } catch (Exception e) { | ||||
|                 log.error("",e); | ||||
|                 return internalServerError(); | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         routingDsl.POST("/transformimage").routingTo(req -> { | ||||
|             try { | ||||
|                 Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData(); | ||||
|                 List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles(); | ||||
|                 if (files.isEmpty()) { | ||||
|                     return badRequest(); | ||||
|                 } | ||||
| 
 | ||||
|                 List<SingleImageRecord> records = new ArrayList<>(); | ||||
| 
 | ||||
|                 for (Http.MultipartFormData.FilePart<Files.TemporaryFile> filePart : files) { | ||||
|                     Files.TemporaryFile file = filePart.getRef(); | ||||
|                     if (file != null) { | ||||
|                         SingleImageRecord record = new SingleImageRecord(file.path().toUri()); | ||||
|                         records.add(record); | ||||
|                     } | ||||
|                 } | ||||
| 
 | ||||
|                 BatchImageRecord batch = new BatchImageRecord(records); | ||||
| 
 | ||||
|                 return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType); | ||||
|             } catch (Exception e) { | ||||
|                 log.error("",e); | ||||
|                 return internalServerError(); | ||||
|             } | ||||
|         }); | ||||
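|         // A minimal upload sketch (hypothetical host, port, and file path; the multipart field name is | ||||
|         // arbitrary since every file part in the request is read): | ||||
|         //   curl -X POST -F "file=@/path/to/image.jpg" http://localhost:9060/transformimage | ||||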
| 
 | ||||
|         return routingDsl.build(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) { | ||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) { | ||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) { | ||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) { | ||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setCSVTransformProcess(TransformProcess transformProcess) { | ||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) { | ||||
|         this.transform = new ImageSparkTransform(imageTransformProcess); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public TransformProcess getCSVTransformProcess() { | ||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public ImageTransformProcess getImageTransformProcess() { | ||||
|         return transform.getImageTransformProcess(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public SingleCSVRecord transformIncremental(SingleCSVRecord singleCsvRecord) { | ||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) { | ||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) { | ||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) { | ||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) { | ||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Base64NDArrayBody transformIncrementalArray(SingleImageRecord record) throws IOException { | ||||
|         return transform.toArray(record); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public Base64NDArrayBody transformArray(BatchImageRecord batch) throws IOException { | ||||
|         return transform.toArray(batch); | ||||
|     } | ||||
| 
 | ||||
|     public static void main(String[] args) throws Exception { | ||||
|         new ImageSparkTransformServer().runMain(args); | ||||
|     } | ||||
| } | ||||
| @ -1,67 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.spark.inference.server; | ||||
| 
 | ||||
| import com.beust.jcommander.Parameter; | ||||
| import com.fasterxml.jackson.databind.JsonNode; | ||||
| import org.datavec.spark.inference.model.model.Base64NDArrayBody; | ||||
| import org.datavec.spark.inference.model.model.BatchCSVRecord; | ||||
| import org.datavec.spark.inference.model.service.DataVecTransformService; | ||||
| import org.nd4j.shade.jackson.databind.ObjectMapper; | ||||
| import play.mvc.Http; | ||||
| import play.server.Server; | ||||
| 
 | ||||
| public abstract class SparkTransformServer implements DataVecTransformService { | ||||
|     @Parameter(names = {"-j", "--jsonPath"}, arity = 1) | ||||
|     protected String jsonPath = null; | ||||
|     @Parameter(names = {"-dp", "--dataVecPort"}, arity = 1) | ||||
|     protected int port = 9000; | ||||
|     @Parameter(names = {"-dt", "--dataType"}, arity = 1) | ||||
|     private TransformDataType transformDataType = null; | ||||
|     protected Server server; | ||||
|     protected static ObjectMapper objectMapper = new ObjectMapper(); | ||||
|     protected static String contentType = "application/json"; | ||||
| 
 | ||||
|     public abstract void runMain(String[] args) throws Exception; | ||||
| 
 | ||||
|     /** | ||||
|      * Stop the server | ||||
|      */ | ||||
|     public void stop() { | ||||
|         if (server != null) | ||||
|             server.stop(); | ||||
|     } | ||||
| 
 | ||||
|     protected boolean isSequence(Http.Request request) { | ||||
|         return request.hasHeader(SEQUENCE_OR_NOT_HEADER) | ||||
|                 && request.header(SEQUENCE_OR_NOT_HEADER).get().equalsIgnoreCase("true"); | ||||
|     } | ||||
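|     // Dispatch note: a request is treated as a sequence transform only when the header named by | ||||
|     // SEQUENCE_OR_NOT_HEADER carries the value "true" (compared case-insensitively); otherwise the | ||||
|     // non-sequence handlers are used. | ||||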
| 
 | ||||
|     protected String getJsonText(Http.Request request) { | ||||
|         JsonNode tryJson = request.body().asJson(); | ||||
|         if (tryJson != null) | ||||
|             return tryJson.toString(); | ||||
|         else | ||||
|             return request.body().asText(); | ||||
|     } | ||||
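|     // Body handling note: if Play parsed the request body as JSON it is re-serialized to a string; | ||||
|     // otherwise the raw text body is returned, so clients may also POST JSON with a text content type. | ||||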
| 
 | ||||
|     public abstract Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord); | ||||
| } | ||||
| @ -1,76 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.spark.inference.server; | ||||
| 
 | ||||
| import lombok.Data; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| 
 | ||||
| import java.io.InvalidClassException; | ||||
| import java.util.Arrays; | ||||
| import java.util.List; | ||||
| 
 | ||||
| @Data | ||||
| @Slf4j | ||||
| public class SparkTransformServerChooser { | ||||
|     private SparkTransformServer sparkTransformServer = null; | ||||
|     private TransformDataType transformDataType = null; | ||||
| 
 | ||||
|     public void runMain(String[] args) throws Exception { | ||||
| 
 | ||||
|         int pos = getMatchingPosition(args, "-dt", "--dataType"); | ||||
|         if (pos == -1 || pos + 1 >= args.length) { | ||||
|             log.error("Missing required option: -dt, --dataType   Options: [CSV, IMAGE]"); | ||||
|             throw new Exception("Missing required option: -dt, --dataType"); | ||||
|         } else { | ||||
|             transformDataType = TransformDataType.valueOf(args[pos + 1]); | ||||
|         } | ||||
| 
 | ||||
|         switch (transformDataType) { | ||||
|             case CSV: | ||||
|                 sparkTransformServer = new CSVSparkTransformServer(); | ||||
|                 break; | ||||
|             case IMAGE: | ||||
|                 sparkTransformServer = new ImageSparkTransformServer(); | ||||
|                 break; | ||||
|             default: | ||||
|                 throw new InvalidClassException("no matching SparkTransform class"); | ||||
|         } | ||||
| 
 | ||||
|         sparkTransformServer.runMain(args); | ||||
|     } | ||||
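|     // A minimal launch sketch (hypothetical jar/classpath): the chooser reads -dt and then forwards all | ||||
|     // arguments, including -dt, to the selected server, which declares the same option: | ||||
|     //   java -cp <datavec-spark-inference-server jar> org.datavec.spark.inference.server.SparkTransformServerChooser \ | ||||
|     //       -dt CSV -dp 9050 -j /path/to/transformProcess.json | ||||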
| 
 | ||||
|     private int getMatchingPosition(String[] args, String... options) { | ||||
|         List<String> optionList = Arrays.asList(options); | ||||
| 
 | ||||
|         for (int i = 0; i < args.length; i++) { | ||||
|             if (optionList.contains(args[i])) { | ||||
|                 return i; | ||||
|             } | ||||
|         } | ||||
|         return -1; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public static void main(String[] args) throws Exception { | ||||
|         new SparkTransformServerChooser().runMain(args); | ||||
|     } | ||||
| } | ||||
| @ -1,25 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.spark.inference.server; | ||||
| 
 | ||||
| public enum TransformDataType { | ||||
|     CSV, IMAGE, | ||||
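|     // Matched case-sensitively via TransformDataType.valueOf(...) in SparkTransformServerChooser, | ||||
|     // so the -dt argument must be exactly CSV or IMAGE. | ||||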
| } | ||||
| @ -1,350 +0,0 @@ | ||||
| # This is the main configuration file for the application. | ||||
| # https://www.playframework.com/documentation/latest/ConfigFile | ||||
| # ~~~~~ | ||||
| # Play uses HOCON as its configuration file format.  HOCON has a number | ||||
| # of advantages over other config formats, and two of its features are | ||||
| # worth knowing when modifying settings. | ||||
| # | ||||
| # You can include other configuration files in this main application.conf file: | ||||
| #include "extra-config.conf" | ||||
| # | ||||
| # You can declare variables and substitute for them: | ||||
| #mykey = ${some.value} | ||||
| # | ||||
| # And if an environment variable exists when there is no other substitution, then | ||||
| # HOCON will fall back to substituting the environment variable: | ||||
| #mykey = ${JAVA_HOME} | ||||
| 
 | ||||
| ## Akka | ||||
| # https://www.playframework.com/documentation/latest/ScalaAkka#Configuration | ||||
| # https://www.playframework.com/documentation/latest/JavaAkka#Configuration | ||||
| # ~~~~~ | ||||
| # Play uses Akka internally and exposes Akka Streams and actors in Websockets and | ||||
| # other streaming HTTP responses. | ||||
| akka { | ||||
|   # "akka.log-config-on-start" is extraordinarly useful because it log the complete | ||||
|   # configuration at INFO level, including defaults and overrides, so it s worth | ||||
|   # putting at the very top. | ||||
|   # | ||||
|   # Put the following in your conf/logback.xml file: | ||||
|   # | ||||
|   # <logger name="akka.actor" level="INFO" /> | ||||
|   # | ||||
|   # And then uncomment this line to debug the configuration. | ||||
|   # | ||||
|   #log-config-on-start = true | ||||
| } | ||||
| 
 | ||||
| ## Modules | ||||
| # https://www.playframework.com/documentation/latest/Modules | ||||
| # ~~~~~ | ||||
| # Control which modules are loaded when Play starts.  Note that modules are | ||||
| # the replacement for "GlobalSettings", which is deprecated in 2.5.x. | ||||
| # Please see https://www.playframework.com/documentation/latest/GlobalSettings | ||||
| # for more information. | ||||
| # | ||||
| # You can also extend Play functionality by using one of the publicly available | ||||
| # Play modules: https://playframework.com/documentation/latest/ModuleDirectory | ||||
| play.modules { | ||||
|   # By default, Play will load any class called Module that is defined | ||||
|   # in the root package (the "app" directory), or you can define them | ||||
|   # explicitly below. | ||||
|   #enabled += my.application.Module | ||||
| 
 | ||||
|   # If there are any built-in modules that you want to disable, you can list them here. | ||||
|   #disabled += "" | ||||
| } | ||||
| 
 | ||||
| ## Internationalisation | ||||
| # https://www.playframework.com/documentation/latest/JavaI18N | ||||
| # https://www.playframework.com/documentation/latest/ScalaI18N | ||||
| # ~~~~~ | ||||
| # Play comes with its own i18n settings, which allow the user's preferred language | ||||
| # to map through to internal messages, or allow the language to be stored in a cookie. | ||||
| play.i18n { | ||||
|   # The application languages | ||||
|   langs = [ "en" ] | ||||
| 
 | ||||
|   # Whether the language cookie should be secure or not | ||||
|   #langCookieSecure = true | ||||
| 
 | ||||
|   # Whether the HTTP only attribute of the cookie should be set to true | ||||
|   #langCookieHttpOnly = true | ||||
| } | ||||
| 
 | ||||
| ## Play HTTP settings | ||||
| # ~~~~~ | ||||
| play.http { | ||||
|   ## Router | ||||
|   # https://www.playframework.com/documentation/latest/JavaRouting | ||||
|   # https://www.playframework.com/documentation/latest/ScalaRouting | ||||
|   # ~~~~~ | ||||
|   # Define the Router object to use for this application. | ||||
|   # This router will be looked up first when the application is starting up, | ||||
|   # so make sure this is the entry point. | ||||
|   # Furthermore, it's assumed your route file is named properly. | ||||
|   # So for an application router like `my.application.Router`, | ||||
|   # you may need to define a router file `conf/my.application.routes`. | ||||
|   # Defaults to Routes in the root package (aka the "app" folder) (and conf/routes) | ||||
|   #router = my.application.Router | ||||
| 
 | ||||
|   ## Action Creator | ||||
|   # https://www.playframework.com/documentation/latest/JavaActionCreator | ||||
|   # ~~~~~ | ||||
|   #actionCreator = null | ||||
| 
 | ||||
|   ## ErrorHandler | ||||
|   # https://www.playframework.com/documentation/latest/JavaRouting | ||||
|   # https://www.playframework.com/documentation/latest/ScalaRouting | ||||
|   # ~~~~~ | ||||
|   # If null, will attempt to load a class called ErrorHandler in the root package, | ||||
|   #errorHandler = null | ||||
| 
 | ||||
|   ## Filters | ||||
|   # https://www.playframework.com/documentation/latest/ScalaHttpFilters | ||||
|   # https://www.playframework.com/documentation/latest/JavaHttpFilters | ||||
|   # ~~~~~ | ||||
|   # Filters run code on every request. They can be used to perform | ||||
|   # common logic for all your actions, e.g. adding common headers. | ||||
|   # Defaults to "Filters" in the root package (aka "apps" folder) | ||||
|   # Alternatively you can explicitly register a class here. | ||||
|   #filters += my.application.Filters | ||||
| 
 | ||||
|   ## Session & Flash | ||||
|   # https://www.playframework.com/documentation/latest/JavaSessionFlash | ||||
|   # https://www.playframework.com/documentation/latest/ScalaSessionFlash | ||||
|   # ~~~~~ | ||||
|   session { | ||||
|     # Sets the cookie to be sent only over HTTPS. | ||||
|     #secure = true | ||||
| 
 | ||||
|     # Sets the cookie to be accessed only by the server. | ||||
|     #httpOnly = true | ||||
| 
 | ||||
|     # Sets the max-age field of the cookie to 5 minutes. | ||||
|     # NOTE: this only controls when the browser will discard the cookie. Play will consider any | ||||
|     # cookie value with a valid signature to be a valid session forever. To implement a server side session timeout, | ||||
|     # you need to put a timestamp in the session and check it at regular intervals to possibly expire it. | ||||
|     #maxAge = 300 | ||||
| 
 | ||||
|     # Sets the domain on the session cookie. | ||||
|     #domain = "example.com" | ||||
|   } | ||||
| 
 | ||||
|   flash { | ||||
|     # Sets the cookie to be sent only over HTTPS. | ||||
|     #secure = true | ||||
| 
 | ||||
|     # Sets the cookie to be accessed only by the server. | ||||
|     #httpOnly = true | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| ## Netty Provider | ||||
| # https://www.playframework.com/documentation/latest/SettingsNetty | ||||
| # ~~~~~ | ||||
| play.server.netty { | ||||
|   # Whether the Netty wire should be logged | ||||
|   #log.wire = true | ||||
| 
 | ||||
|   # If you run Play on Linux, you can use Netty's native socket transport | ||||
|   # for higher performance with less garbage. | ||||
|   #transport = "native" | ||||
| } | ||||
| 
 | ||||
| ## WS (HTTP Client) | ||||
| # https://www.playframework.com/documentation/latest/ScalaWS#Configuring-WS | ||||
| # ~~~~~ | ||||
| # The HTTP client primarily used for REST APIs.  The default client can be | ||||
| # configured directly, but you can also create different client instances | ||||
| # with customized settings. You must enable this by adding to build.sbt: | ||||
| # | ||||
| # libraryDependencies += ws // or javaWs if using java | ||||
| # | ||||
| play.ws { | ||||
|   # Sets HTTP requests not to follow 302 redirects | ||||
|   #followRedirects = false | ||||
| 
 | ||||
|   # Sets the maximum number of open HTTP connections for the client. | ||||
|   #ahc.maxConnectionsTotal = 50 | ||||
| 
 | ||||
|   ## WS SSL | ||||
|   # https://www.playframework.com/documentation/latest/WsSSL | ||||
|   # ~~~~~ | ||||
|   ssl { | ||||
|     # Configuring HTTPS with Play WS does not require programming.  You can | ||||
|     # set up both trustManager and keyManager for mutual authentication, and | ||||
|     # turn on JSSE debugging in development with a reload. | ||||
|     #debug.handshake = true | ||||
|     #trustManager = { | ||||
|     #  stores = [ | ||||
|     #    { type = "JKS", path = "exampletrust.jks" } | ||||
|     #  ] | ||||
|     #} | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| ## Cache | ||||
| # https://www.playframework.com/documentation/latest/JavaCache | ||||
| # https://www.playframework.com/documentation/latest/ScalaCache | ||||
| # ~~~~~ | ||||
| # Play comes with an integrated cache API that can reduce the operational | ||||
| # overhead of repeated requests. You must enable this by adding to build.sbt: | ||||
| # | ||||
| # libraryDependencies += cache | ||||
| # | ||||
| play.cache { | ||||
|   # If you want to bind several caches, you can bind them individually | ||||
|   #bindCaches = ["db-cache", "user-cache", "session-cache"] | ||||
| } | ||||
| 
 | ||||
| ## Filters | ||||
| # https://www.playframework.com/documentation/latest/Filters | ||||
| # ~~~~~ | ||||
| # There are a number of built-in filters that can be enabled and configured | ||||
| # to give Play greater security.  You must enable this by adding to build.sbt: | ||||
| # | ||||
| # libraryDependencies += filters | ||||
| # | ||||
| play.filters { | ||||
|   ## CORS filter configuration | ||||
|   # https://www.playframework.com/documentation/latest/CorsFilter | ||||
|   # ~~~~~ | ||||
|   # CORS is a protocol that allows web applications to make requests from the browser | ||||
|   # across different domains. | ||||
|   # NOTE: You MUST apply the CORS configuration before the CSRF filter, as CSRF has | ||||
|   # dependencies on CORS settings. | ||||
|   cors { | ||||
|     # Filter paths by a whitelist of path prefixes | ||||
|     #pathPrefixes = ["/some/path", ...] | ||||
| 
 | ||||
|     # The allowed origins. If null, all origins are allowed. | ||||
|     #allowedOrigins = ["http://www.example.com"] | ||||
| 
 | ||||
|     # The allowed HTTP methods. If null, all methods are allowed | ||||
|     #allowedHttpMethods = ["GET", "POST"] | ||||
|   } | ||||
| 
 | ||||
|   ## CSRF Filter | ||||
|   # https://www.playframework.com/documentation/latest/ScalaCsrf#Applying-a-global-CSRF-filter | ||||
|   # https://www.playframework.com/documentation/latest/JavaCsrf#Applying-a-global-CSRF-filter | ||||
|   # ~~~~~ | ||||
|   # Play supports multiple methods for verifying that a request is not a CSRF request. | ||||
|   # The primary mechanism is a CSRF token. This token gets placed either in the query string | ||||
|   # or body of every form submitted, and also gets placed in the user's session. | ||||
|   # Play then verifies that both tokens are present and match. | ||||
|   csrf { | ||||
|     # Sets the cookie to be sent only over HTTPS | ||||
|     #cookie.secure = true | ||||
| 
 | ||||
|     # Defaults to CSRFErrorHandler in the root package. | ||||
|     #errorHandler = MyCSRFErrorHandler | ||||
|   } | ||||
| 
 | ||||
|   ## Security headers filter configuration | ||||
|   # https://www.playframework.com/documentation/latest/SecurityHeaders | ||||
|   # ~~~~~ | ||||
|   # Defines security headers that prevent XSS attacks. | ||||
|   # If enabled, then all options are set to the below configuration by default: | ||||
|   headers { | ||||
|     # The X-Frame-Options header. If null, the header is not set. | ||||
|     #frameOptions = "DENY" | ||||
| 
 | ||||
|     # The X-XSS-Protection header. If null, the header is not set. | ||||
|     #xssProtection = "1; mode=block" | ||||
| 
 | ||||
|     # The X-Content-Type-Options header. If null, the header is not set. | ||||
|     #contentTypeOptions = "nosniff" | ||||
| 
 | ||||
|     # The X-Permitted-Cross-Domain-Policies header. If null, the header is not set. | ||||
|     #permittedCrossDomainPolicies = "master-only" | ||||
| 
 | ||||
|     # The Content-Security-Policy header. If null, the header is not set. | ||||
|     #contentSecurityPolicy = "default-src 'self'" | ||||
|   } | ||||
| 
 | ||||
|   ## Allowed hosts filter configuration | ||||
|   # https://www.playframework.com/documentation/latest/AllowedHostsFilter | ||||
|   # ~~~~~ | ||||
|   # Play provides a filter that lets you configure which hosts can access your application. | ||||
|   # This is useful to prevent cache poisoning attacks. | ||||
|   hosts { | ||||
|     # Allow requests to example.com, its subdomains, and localhost:9000. | ||||
|     #allowed = [".example.com", "localhost:9000"] | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| ## Evolutions | ||||
| # https://www.playframework.com/documentation/latest/Evolutions | ||||
| # ~~~~~ | ||||
| # Evolutions allows database scripts to be automatically run on startup in dev mode | ||||
| # for database migrations. You must enable this by adding to build.sbt: | ||||
| # | ||||
| # libraryDependencies += evolutions | ||||
| # | ||||
| play.evolutions { | ||||
|   # You can disable evolutions for a specific datasource if necessary | ||||
|   #db.default.enabled = false | ||||
| } | ||||
| 
 | ||||
| ## Database Connection Pool | ||||
| # https://www.playframework.com/documentation/latest/SettingsJDBC | ||||
| # ~~~~~ | ||||
| # Play doesn't require a JDBC database to run, but you can easily enable one. | ||||
| # | ||||
| # libraryDependencies += jdbc | ||||
| # | ||||
| play.db { | ||||
|   # The combination of these two settings results in "db.default" as the | ||||
|   # default JDBC pool: | ||||
|   #config = "db" | ||||
|   #default = "default" | ||||
| 
 | ||||
|   # Play uses HikariCP as the default connection pool.  You can override | ||||
|   # settings by changing the prototype: | ||||
|   prototype { | ||||
|     # Sets a fixed JDBC connection pool size of 50 | ||||
|     #hikaricp.minimumIdle = 50 | ||||
|     #hikaricp.maximumPoolSize = 50 | ||||
|   } | ||||
| } | ||||
| 
 | ||||
| ## JDBC Datasource | ||||
| # https://www.playframework.com/documentation/latest/JavaDatabase | ||||
| # https://www.playframework.com/documentation/latest/ScalaDatabase | ||||
| # ~~~~~ | ||||
| # Once JDBC datasource is set up, you can work with several different | ||||
| # database options: | ||||
| # | ||||
| # Slick (Scala preferred option): https://www.playframework.com/documentation/latest/PlaySlick | ||||
| # JPA (Java preferred option): https://playframework.com/documentation/latest/JavaJPA | ||||
| # EBean: https://playframework.com/documentation/latest/JavaEbean | ||||
| # Anorm: https://www.playframework.com/documentation/latest/ScalaAnorm | ||||
| # | ||||
| db { | ||||
|   # You can declare as many datasources as you want. | ||||
|   # By convention, the default datasource is named `default` | ||||
| 
 | ||||
|   # https://www.playframework.com/documentation/latest/Developing-with-the-H2-Database | ||||
|   default.driver = org.h2.Driver | ||||
|   default.url = "jdbc:h2:mem:play" | ||||
|   #default.username = sa | ||||
|   #default.password = "" | ||||
| 
 | ||||
|   # You can expose this datasource via JNDI if needed (Useful for JPA) | ||||
|   default.jndiName=DefaultDS | ||||
| 
 | ||||
|   # You can turn on SQL logging for any datasource | ||||
|   # https://www.playframework.com/documentation/latest/Highlights25#Logging-SQL-statements | ||||
|   #default.logSql=true | ||||
| } | ||||
| 
 | ||||
| jpa.default=defaultPersistenceUnit | ||||
| 
 | ||||
| 
 | ||||
| #Increase default maximum post length - used for remote listener functionality | ||||
| #Can get response 413 with larger networks without setting this | ||||
| # parsers.text.maxLength is deprecated, use play.http.parser.maxMemoryBuffer instead | ||||
| #parsers.text.maxLength=10M | ||||
| play.http.parser.maxMemoryBuffer=10M | ||||
| @ -1,46 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| package org.datavec.spark.transform; | ||||
| 
 | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.nd4j.common.tests.AbstractAssertTestsClass; | ||||
| import org.nd4j.common.tests.BaseND4JTest; | ||||
| 
 | ||||
| import java.util.*; | ||||
| 
 | ||||
| @Slf4j | ||||
| public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { | ||||
| 
 | ||||
|     @Override | ||||
|     protected Set<Class<?>> getExclusions() { | ||||
|         // Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) | ||||
|         return new HashSet<>(); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     protected String getPackageName() { | ||||
|         return "org.datavec.spark.transform"; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     protected Class<?> getBaseClass() { | ||||
|         return BaseND4JTest.class; | ||||
|     } | ||||
| } | ||||
| @ -1,127 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.spark.transform; | ||||
| 
 | ||||
| import com.mashape.unirest.http.JsonNode; | ||||
| import com.mashape.unirest.http.ObjectMapper; | ||||
| import com.mashape.unirest.http.Unirest; | ||||
| import org.apache.commons.io.FileUtils; | ||||
| import org.datavec.api.transform.TransformProcess; | ||||
| import org.datavec.api.transform.schema.Schema; | ||||
| import org.datavec.spark.inference.server.CSVSparkTransformServer; | ||||
| import org.datavec.spark.inference.model.model.Base64NDArrayBody; | ||||
| import org.datavec.spark.inference.model.model.BatchCSVRecord; | ||||
| import org.datavec.spark.inference.model.model.SingleCSVRecord; | ||||
| import org.junit.AfterClass; | ||||
| import org.junit.BeforeClass; | ||||
| import org.junit.Test; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.util.UUID; | ||||
| 
 | ||||
| import static org.junit.Assert.assertNull; | ||||
| import static org.junit.Assume.assumeNotNull; | ||||
| 
 | ||||
| public class CSVSparkTransformServerNoJsonTest { | ||||
| 
 | ||||
|     private static CSVSparkTransformServer server; | ||||
|     private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); | ||||
|     private static TransformProcess transformProcess = | ||||
|                     new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build(); | ||||
|     private static File fileSave = new File(UUID.randomUUID().toString() + ".json"); | ||||
| 
 | ||||
|     @BeforeClass | ||||
|     public static void before() throws Exception { | ||||
|         server = new CSVSparkTransformServer(); | ||||
|         FileUtils.write(fileSave, transformProcess.toJson()); | ||||
| 
 | ||||
|         // Only one time | ||||
|         Unirest.setObjectMapper(new ObjectMapper() { | ||||
|             private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = | ||||
|                     new org.nd4j.shade.jackson.databind.ObjectMapper(); | ||||
| 
 | ||||
|             public <T> T readValue(String value, Class<T> valueType) { | ||||
|                 try { | ||||
|                     return jacksonObjectMapper.readValue(value, valueType); | ||||
|                 } catch (IOException e) { | ||||
|                     throw new RuntimeException(e); | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             public String writeValue(Object value) { | ||||
|                 try { | ||||
|                     return jacksonObjectMapper.writeValueAsString(value); | ||||
|                 } catch (Exception e) { | ||||
|                     throw new RuntimeException(e); | ||||
|                 } | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         server.runMain(new String[] {"-dp", "9050"}); | ||||
|     } | ||||
| 
 | ||||
|     @AfterClass | ||||
|     public static void after() throws Exception { | ||||
|         fileSave.delete(); | ||||
|         server.stop(); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testServer() throws Exception { | ||||
|         assertNull(server.getTransform()); | ||||
|         JsonNode jsonStatus = Unirest.post("http://localhost:9050/transformprocess") | ||||
|                         .header("accept", "application/json").header("Content-Type", "application/json") | ||||
|                         .body(transformProcess.toJson()).asJson().getBody(); | ||||
|         assumeNotNull(server.getTransform()); | ||||
| 
 | ||||
|         String[] values = new String[] {"1.0", "2.0"}; | ||||
|         SingleCSVRecord record = new SingleCSVRecord(values); | ||||
|         JsonNode jsonNode = | ||||
|                         Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json") | ||||
|                                         .header("Content-Type", "application/json").body(record).asJson().getBody(); | ||||
|         SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental") | ||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(record) | ||||
|                         .asObject(SingleCSVRecord.class).getBody(); | ||||
| 
 | ||||
|         BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); | ||||
|         for (int i = 0; i < 3; i++) | ||||
|             batchCSVRecord.add(singleCsvRecord); | ||||
|     /*    BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform") | ||||
|                         .header("accept", "application/json").header("Content-Type", "application/json") | ||||
|                         .body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody(); | ||||
| 
 | ||||
|         Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray") | ||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(record) | ||||
|                         .asObject(Base64NDArrayBody.class).getBody(); | ||||
| */ | ||||
|         Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray") | ||||
|                         .header("accept", "application/json").header("Content-Type", "application/json") | ||||
|                         .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody(); | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,121 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.spark.transform; | ||||
| 
 | ||||
| 
 | ||||
| import com.mashape.unirest.http.JsonNode; | ||||
| import com.mashape.unirest.http.ObjectMapper; | ||||
| import com.mashape.unirest.http.Unirest; | ||||
| import org.apache.commons.io.FileUtils; | ||||
| import org.datavec.api.transform.TransformProcess; | ||||
| import org.datavec.api.transform.schema.Schema; | ||||
| import org.datavec.spark.inference.server.CSVSparkTransformServer; | ||||
| import org.datavec.spark.inference.model.model.Base64NDArrayBody; | ||||
| import org.datavec.spark.inference.model.model.BatchCSVRecord; | ||||
| import org.datavec.spark.inference.model.model.SingleCSVRecord; | ||||
| import org.junit.AfterClass; | ||||
| import org.junit.BeforeClass; | ||||
| import org.junit.Test; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.util.UUID; | ||||
| 
 | ||||
| public class CSVSparkTransformServerTest { | ||||
| 
 | ||||
|     private static CSVSparkTransformServer server; | ||||
|     private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); | ||||
|     private static TransformProcess transformProcess = | ||||
|                     new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build(); | ||||
|     private static File fileSave = new File(UUID.randomUUID().toString() + ".json"); | ||||
| 
 | ||||
|     @BeforeClass | ||||
|     public static void before() throws Exception { | ||||
|         server = new CSVSparkTransformServer(); | ||||
|         FileUtils.write(fileSave, transformProcess.toJson()); | ||||
|         // Only one time | ||||
| 
 | ||||
|         Unirest.setObjectMapper(new ObjectMapper() { | ||||
|             private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = | ||||
|                             new org.nd4j.shade.jackson.databind.ObjectMapper(); | ||||
| 
 | ||||
|             public <T> T readValue(String value, Class<T> valueType) { | ||||
|                 try { | ||||
|                     return jacksonObjectMapper.readValue(value, valueType); | ||||
|                 } catch (IOException e) { | ||||
|                     throw new RuntimeException(e); | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             public String writeValue(Object value) { | ||||
|                 try { | ||||
|                     return jacksonObjectMapper.writeValueAsString(value); | ||||
|                 } catch (Exception e) { | ||||
|                     throw new RuntimeException(e); | ||||
|                 } | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9050"}); | ||||
|     } | ||||
| 
 | ||||
|     @AfterClass | ||||
|     public static void after() throws Exception { | ||||
|         fileSave.deleteOnExit(); | ||||
|         server.stop(); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testServer() throws Exception { | ||||
|         String[] values = new String[] {"1.0", "2.0"}; | ||||
|         SingleCSVRecord record = new SingleCSVRecord(values); | ||||
|         JsonNode jsonNode = | ||||
|                         Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json") | ||||
|                                         .header("Content-Type", "application/json").body(record).asJson().getBody(); | ||||
|         SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental") | ||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(record) | ||||
|                         .asObject(SingleCSVRecord.class).getBody(); | ||||
| 
 | ||||
|         BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); | ||||
|         for (int i = 0; i < 3; i++) | ||||
|             batchCSVRecord.add(singleCsvRecord); | ||||
|         BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform") | ||||
|                         .header("accept", "application/json").header("Content-Type", "application/json") | ||||
|                         .body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody(); | ||||
| 
 | ||||
|         Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray") | ||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(record) | ||||
|                         .asObject(Base64NDArrayBody.class).getBody(); | ||||
| 
 | ||||
|         Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray") | ||||
|                         .header("accept", "application/json").header("Content-Type", "application/json") | ||||
|                         .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody(); | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,164 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.spark.transform; | ||||
| 
 | ||||
| 
 | ||||
| import com.mashape.unirest.http.JsonNode; | ||||
| import com.mashape.unirest.http.ObjectMapper; | ||||
| import com.mashape.unirest.http.Unirest; | ||||
| import org.apache.commons.io.FileUtils; | ||||
| import org.datavec.image.transform.ImageTransformProcess; | ||||
| import org.datavec.spark.inference.server.ImageSparkTransformServer; | ||||
| import org.datavec.spark.inference.model.model.Base64NDArrayBody; | ||||
| import org.datavec.spark.inference.model.model.BatchImageRecord; | ||||
| import org.datavec.spark.inference.model.model.SingleImageRecord; | ||||
| import org.junit.AfterClass; | ||||
| import org.junit.BeforeClass; | ||||
| import org.junit.Rule; | ||||
| import org.junit.Test; | ||||
| import org.junit.rules.TemporaryFolder; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.serde.base64.Nd4jBase64; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.util.UUID; | ||||
| 
 | ||||
| import static org.junit.Assert.assertEquals; | ||||
| 
 | ||||
| public class ImageSparkTransformServerTest { | ||||
| 
 | ||||
|     @Rule | ||||
|     public TemporaryFolder testDir = new TemporaryFolder(); | ||||
| 
 | ||||
|     private static ImageSparkTransformServer server; | ||||
|     private static File fileSave = new File(UUID.randomUUID().toString() + ".json"); | ||||
| 
 | ||||
|     @BeforeClass | ||||
|     public static void before() throws Exception { | ||||
|         server = new ImageSparkTransformServer(); | ||||
| 
 | ||||
|         ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(12345) | ||||
|                         .scaleImageTransform(10).cropImageTransform(5).build(); | ||||
| 
 | ||||
|         FileUtils.write(fileSave, imgTransformProcess.toJson()); | ||||
| 
 | ||||
|         Unirest.setObjectMapper(new ObjectMapper() { | ||||
|             private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = | ||||
|                             new org.nd4j.shade.jackson.databind.ObjectMapper(); | ||||
| 
 | ||||
|             public <T> T readValue(String value, Class<T> valueType) { | ||||
|                 try { | ||||
|                     return jacksonObjectMapper.readValue(value, valueType); | ||||
|                 } catch (IOException e) { | ||||
|                     throw new RuntimeException(e); | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             public String writeValue(Object value) { | ||||
|                 try { | ||||
|                     return jacksonObjectMapper.writeValueAsString(value); | ||||
|                 } catch (Exception e) { | ||||
|                     throw new RuntimeException(e); | ||||
|                 } | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9060"}); | ||||
|     } | ||||
| 
 | ||||
|     @AfterClass | ||||
|     public static void after() throws Exception { | ||||
|         fileSave.deleteOnExit(); | ||||
|         server.stop(); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testImageServer() throws Exception { | ||||
|         SingleImageRecord record = | ||||
|                         new SingleImageRecord(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI()); | ||||
|         JsonNode jsonNode = Unirest.post("http://localhost:9060/transformincrementalarray") | ||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(record) | ||||
|                         .asJson().getBody(); | ||||
|         Base64NDArrayBody array = Unirest.post("http://localhost:9060/transformincrementalarray") | ||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(record) | ||||
|                         .asObject(Base64NDArrayBody.class).getBody(); | ||||
| 
 | ||||
|         BatchImageRecord batch = new BatchImageRecord(); | ||||
|         batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI()); | ||||
|         batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile().toURI()); | ||||
|         batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile().toURI()); | ||||
| 
 | ||||
|         JsonNode jsonNodeBatch = | ||||
|                         Unirest.post("http://localhost:9060/transformarray").header("accept", "application/json") | ||||
|                                         .header("Content-Type", "application/json").body(batch).asJson().getBody(); | ||||
|         Base64NDArrayBody batchArray = Unirest.post("http://localhost:9060/transformarray") | ||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(batch) | ||||
|                         .asObject(Base64NDArrayBody.class).getBody(); | ||||
| 
 | ||||
|         INDArray result = getNDArray(jsonNode); | ||||
|         assertEquals(1, result.size(0)); | ||||
| 
 | ||||
|         INDArray batchResult = getNDArray(jsonNodeBatch); | ||||
|         assertEquals(3, batchResult.size(0)); | ||||
| 
 | ||||
| //        System.out.println(array); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testImageServerMultipart() throws Exception { | ||||
|         JsonNode jsonNode = Unirest.post("http://localhost:9060/transformimage") | ||||
|                 .header("accept", "application/json") | ||||
|                 .field("file1", new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile()) | ||||
|                 .field("file2", new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile()) | ||||
|                 .field("file3", new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile()) | ||||
|                 .asJson().getBody(); | ||||
| 
 | ||||
| 
 | ||||
|         INDArray batchResult = getNDArray(jsonNode); | ||||
|         assertEquals(3, batchResult.size(0)); | ||||
| 
 | ||||
| //        System.out.println(batchResult); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testImageServerSingleMultipart() throws Exception { | ||||
|         File f = testDir.newFolder(); | ||||
|         File imgFile = new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getTempFileFromArchive(f); | ||||
| 
 | ||||
|         JsonNode jsonNode = Unirest.post("http://localhost:9060/transformimage") | ||||
|                 .header("accept", "application/json") | ||||
|                 .field("file1", imgFile) | ||||
|                 .asJson().getBody(); | ||||
| 
 | ||||
| 
 | ||||
|         INDArray result = getNDArray(jsonNode); | ||||
|         assertEquals(1, result.size(0)); | ||||
| 
 | ||||
| //        System.out.println(result); | ||||
|     } | ||||
| 
 | ||||
|     public INDArray getNDArray(JsonNode node) throws IOException { | ||||
|         return Nd4jBase64.fromBase64(node.getObject().getString("ndarray")); | ||||
|     } | ||||
| } | ||||
| @ -1,168 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.datavec.spark.transform; | ||||
| 
 | ||||
| 
 | ||||
| import com.mashape.unirest.http.JsonNode; | ||||
| import com.mashape.unirest.http.ObjectMapper; | ||||
| import com.mashape.unirest.http.Unirest; | ||||
| import org.apache.commons.io.FileUtils; | ||||
| import org.datavec.api.transform.TransformProcess; | ||||
| import org.datavec.api.transform.schema.Schema; | ||||
| import org.datavec.image.transform.ImageTransformProcess; | ||||
| import org.datavec.spark.inference.server.SparkTransformServerChooser; | ||||
| import org.datavec.spark.inference.server.TransformDataType; | ||||
| import org.datavec.spark.inference.model.model.*; | ||||
| import org.junit.AfterClass; | ||||
| import org.junit.BeforeClass; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.common.io.ClassPathResource; | ||||
| import org.nd4j.serde.base64.Nd4jBase64; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.util.UUID; | ||||
| 
 | ||||
| import static org.junit.Assert.assertEquals; | ||||
| 
 | ||||
| public class SparkTransformServerTest { | ||||
|     private static SparkTransformServerChooser serverChooser; | ||||
|     private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); | ||||
|     private static TransformProcess transformProcess = | ||||
|                     new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build(); | ||||
| 
 | ||||
|     private static File imageTransformFile = new File(UUID.randomUUID().toString() + ".json"); | ||||
|     private static File csvTransformFile = new File(UUID.randomUUID().toString() + ".json"); | ||||
| 
 | ||||
|     @BeforeClass | ||||
|     public static void before() throws Exception { | ||||
|         serverChooser = new SparkTransformServerChooser(); | ||||
| 
 | ||||
|         ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(12345) | ||||
|                         .scaleImageTransform(10).cropImageTransform(5).build(); | ||||
| 
 | ||||
|         FileUtils.write(imageTransformFile, imgTransformProcess.toJson()); | ||||
| 
 | ||||
|         FileUtils.write(csvTransformFile, transformProcess.toJson()); | ||||
| 
 | ||||
|         Unirest.setObjectMapper(new ObjectMapper() { | ||||
|             private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = | ||||
|                             new org.nd4j.shade.jackson.databind.ObjectMapper(); | ||||
| 
 | ||||
|             public <T> T readValue(String value, Class<T> valueType) { | ||||
|                 try { | ||||
|                     return jacksonObjectMapper.readValue(value, valueType); | ||||
|                 } catch (IOException e) { | ||||
|                     throw new RuntimeException(e); | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             public String writeValue(Object value) { | ||||
|                 try { | ||||
|                     return jacksonObjectMapper.writeValueAsString(value); | ||||
|                 } catch (Exception e) { | ||||
|                     throw new RuntimeException(e); | ||||
|                 } | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @AfterClass | ||||
|     public static void after() throws Exception { | ||||
|         imageTransformFile.deleteOnExit(); | ||||
|         csvTransformFile.deleteOnExit(); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testImageServer() throws Exception { | ||||
|         serverChooser.runMain(new String[] {"--jsonPath", imageTransformFile.getAbsolutePath(), "-dp", "9060", "-dt", | ||||
|                         TransformDataType.IMAGE.toString()}); | ||||
| 
 | ||||
|         SingleImageRecord record = | ||||
|                         new SingleImageRecord(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI()); | ||||
|         JsonNode jsonNode = Unirest.post("http://localhost:9060/transformincrementalarray") | ||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(record) | ||||
|                         .asJson().getBody(); | ||||
|         Base64NDArrayBody array = Unirest.post("http://localhost:9060/transformincrementalarray") | ||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(record) | ||||
|                         .asObject(Base64NDArrayBody.class).getBody(); | ||||
| 
 | ||||
|         BatchImageRecord batch = new BatchImageRecord(); | ||||
|         batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI()); | ||||
|         batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile().toURI()); | ||||
|         batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile().toURI()); | ||||
| 
 | ||||
|         JsonNode jsonNodeBatch = | ||||
|                         Unirest.post("http://localhost:9060/transformarray").header("accept", "application/json") | ||||
|                                         .header("Content-Type", "application/json").body(batch).asJson().getBody(); | ||||
|         Base64NDArrayBody batchArray = Unirest.post("http://localhost:9060/transformarray") | ||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(batch) | ||||
|                         .asObject(Base64NDArrayBody.class).getBody(); | ||||
| 
 | ||||
|         INDArray result = getNDArray(jsonNode); | ||||
|         assertEquals(1, result.size(0)); | ||||
| 
 | ||||
|         INDArray batchResult = getNDArray(jsonNodeBatch); | ||||
|         assertEquals(3, batchResult.size(0)); | ||||
| 
 | ||||
|         serverChooser.getSparkTransformServer().stop(); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testCSVServer() throws Exception { | ||||
|         serverChooser.runMain(new String[] {"--jsonPath", csvTransformFile.getAbsolutePath(), "-dp", "9050", "-dt", | ||||
|                         TransformDataType.CSV.toString()}); | ||||
| 
 | ||||
|         String[] values = new String[] {"1.0", "2.0"}; | ||||
|         SingleCSVRecord record = new SingleCSVRecord(values); | ||||
|         JsonNode jsonNode = | ||||
|                         Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json") | ||||
|                                         .header("Content-Type", "application/json").body(record).asJson().getBody(); | ||||
|         SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental") | ||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(record) | ||||
|                         .asObject(SingleCSVRecord.class).getBody(); | ||||
| 
 | ||||
|         BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); | ||||
|         for (int i = 0; i < 3; i++) | ||||
|             batchCSVRecord.add(singleCsvRecord); | ||||
|         BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform") | ||||
|                         .header("accept", "application/json").header("Content-Type", "application/json") | ||||
|                         .body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody(); | ||||
| 
 | ||||
|         Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray") | ||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(record) | ||||
|                         .asObject(Base64NDArrayBody.class).getBody(); | ||||
| 
 | ||||
|         Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray") | ||||
|                         .header("accept", "application/json").header("Content-Type", "application/json") | ||||
|                         .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody(); | ||||
| 
 | ||||
| 
 | ||||
|         serverChooser.getSparkTransformServer().stop(); | ||||
|     } | ||||
| 
 | ||||
|     public INDArray getNDArray(JsonNode node) throws IOException { | ||||
|         return Nd4jBase64.fromBase64(node.getObject().getString("ndarray")); | ||||
|     } | ||||
| } | ||||
| @ -1,6 +0,0 @@ | ||||
| play.modules.enabled += com.lightbend.lagom.discovery.zookeeper.ZooKeeperServiceLocatorModule | ||||
| play.modules.enabled += io.skymind.skil.service.PredictionModule | ||||
| play.crypto.secret = as8dufasdfuasdfjkasdkfalksjfk | ||||
| play.server.pidfile.path=/tmp/RUNNING_PID | ||||
| 
 | ||||
| play.server.http.port = 9600 | ||||
| @ -1,68 +0,0 @@ | ||||
| <?xml version="1.0" encoding="UTF-8"?> | ||||
| <!-- | ||||
|   ~ /* ****************************************************************************** | ||||
|   ~  * | ||||
|   ~  * | ||||
|   ~  * This program and the accompanying materials are made available under the | ||||
|   ~  * terms of the Apache License, Version 2.0 which is available at | ||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|   ~  * | ||||
|   ~  *  See the NOTICE file distributed with this work for additional | ||||
|   ~  *  information regarding copyright ownership. | ||||
|   ~  * Unless required by applicable law or agreed to in writing, software | ||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|   ~  * License for the specific language governing permissions and limitations | ||||
|   ~  * under the License. | ||||
|   ~  * | ||||
|   ~  * SPDX-License-Identifier: Apache-2.0 | ||||
|   ~  ******************************************************************************/ | ||||
|   --> | ||||
| 
 | ||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" | ||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||||
| 
 | ||||
|     <modelVersion>4.0.0</modelVersion> | ||||
| 
 | ||||
|     <parent> | ||||
|         <groupId>org.datavec</groupId> | ||||
|         <artifactId>datavec-parent</artifactId> | ||||
|         <version>1.0.0-SNAPSHOT</version> | ||||
|     </parent> | ||||
| 
 | ||||
|     <artifactId>datavec-spark-inference-parent</artifactId> | ||||
|     <packaging>pom</packaging> | ||||
| 
 | ||||
|     <name>datavec-spark-inference-parent</name> | ||||
| 
 | ||||
|     <modules> | ||||
|         <module>datavec-spark-inference-server</module> | ||||
|         <module>datavec-spark-inference-client</module> | ||||
|         <module>datavec-spark-inference-model</module> | ||||
|     </modules> | ||||
| 
 | ||||
|     <dependencyManagement> | ||||
|         <dependencies> | ||||
|             <dependency> | ||||
|                 <groupId>org.datavec</groupId> | ||||
|                 <artifactId>datavec-data-image</artifactId> | ||||
|                 <version>${datavec.version}</version> | ||||
|             </dependency> | ||||
|             <dependency> | ||||
|                 <groupId>com.mashape.unirest</groupId> | ||||
|                 <artifactId>unirest-java</artifactId> | ||||
|                 <version>${unirest.version}</version> | ||||
|             </dependency> | ||||
|         </dependencies> | ||||
|     </dependencyManagement> | ||||
| 
 | ||||
|     <profiles> | ||||
|         <profile> | ||||
|             <id>test-nd4j-native</id> | ||||
|         </profile> | ||||
|         <profile> | ||||
|             <id>test-nd4j-cuda-11.0</id> | ||||
|         </profile> | ||||
|     </profiles> | ||||
| </project> | ||||
| @ -45,7 +45,6 @@ | ||||
|         <module>datavec-data</module> | ||||
|         <module>datavec-spark</module> | ||||
|         <module>datavec-local</module> | ||||
|         <module>datavec-spark-inference-parent</module> | ||||
|         <module>datavec-jdbc</module> | ||||
|         <module>datavec-excel</module> | ||||
|         <module>datavec-arrow</module> | ||||
|  | ||||
| @ -1,143 +0,0 @@ | ||||
| <?xml version="1.0" encoding="UTF-8"?> | ||||
| <!-- | ||||
|   ~ /* ****************************************************************************** | ||||
|   ~  * | ||||
|   ~  * | ||||
|   ~  * This program and the accompanying materials are made available under the | ||||
|   ~  * terms of the Apache License, Version 2.0 which is available at | ||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|   ~  * | ||||
|   ~  *  See the NOTICE file distributed with this work for additional | ||||
|   ~  *  information regarding copyright ownership. | ||||
|   ~  * Unless required by applicable law or agreed to in writing, software | ||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|   ~  * License for the specific language governing permissions and limitations | ||||
|   ~  * under the License. | ||||
|   ~  * | ||||
|   ~  * SPDX-License-Identifier: Apache-2.0 | ||||
|   ~  ******************************************************************************/ | ||||
|   --> | ||||
| 
 | ||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" | ||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||||
| 
 | ||||
|     <modelVersion>4.0.0</modelVersion> | ||||
| 
 | ||||
|     <parent> | ||||
|         <groupId>org.deeplearning4j</groupId> | ||||
|         <artifactId>deeplearning4j-nearestneighbors-parent</artifactId> | ||||
|         <version>1.0.0-SNAPSHOT</version> | ||||
|     </parent> | ||||
| 
 | ||||
|     <artifactId>deeplearning4j-nearestneighbor-server</artifactId> | ||||
|     <packaging>jar</packaging> | ||||
| 
 | ||||
|     <name>deeplearning4j-nearestneighbor-server</name> | ||||
| 
 | ||||
|     <properties> | ||||
|         <java.compile.version>1.8</java.compile.version> | ||||
|     </properties> | ||||
| 
 | ||||
|     <dependencies> | ||||
|         <dependency> | ||||
|             <groupId>org.deeplearning4j</groupId> | ||||
|             <artifactId>deeplearning4j-nearestneighbors-model</artifactId> | ||||
|             <version>${project.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.deeplearning4j</groupId> | ||||
|             <artifactId>deeplearning4j-core</artifactId> | ||||
|             <version>${project.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>io.vertx</groupId> | ||||
|             <artifactId>vertx-core</artifactId> | ||||
|             <version>${vertx.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>io.vertx</groupId> | ||||
|             <artifactId>vertx-web</artifactId> | ||||
|             <version>${vertx.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>com.mashape.unirest</groupId> | ||||
|             <artifactId>unirest-java</artifactId> | ||||
|             <version>${unirest.version}</version> | ||||
|             <scope>test</scope> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.deeplearning4j</groupId> | ||||
|             <artifactId>deeplearning4j-nearestneighbors-client</artifactId> | ||||
|             <version>${project.version}</version> | ||||
|             <scope>test</scope> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>com.beust</groupId> | ||||
|             <artifactId>jcommander</artifactId> | ||||
|             <version>${jcommander.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>ch.qos.logback</groupId> | ||||
|             <artifactId>logback-classic</artifactId> | ||||
|             <scope>test</scope> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.deeplearning4j</groupId> | ||||
|             <artifactId>deeplearning4j-common-tests</artifactId> | ||||
|             <version>${project.version}</version> | ||||
|             <scope>test</scope> | ||||
|         </dependency> | ||||
|     </dependencies> | ||||
| 
 | ||||
|     <build> | ||||
|         <plugins> | ||||
|             <plugin> | ||||
|                 <groupId>org.apache.maven.plugins</groupId> | ||||
|                 <artifactId>maven-surefire-plugin</artifactId> | ||||
|                 <configuration> | ||||
|                     <argLine>-Dfile.encoding=UTF-8 -Xmx8g</argLine> | ||||
|                     <includes> | ||||
|                         <!-- Default setting only runs tests that start/end with "Test" --> | ||||
|                         <include>*.java</include> | ||||
|                         <include>**/*.java</include> | ||||
|                     </includes> | ||||
|                 </configuration> | ||||
|             </plugin> | ||||
|             <plugin> | ||||
|                 <groupId>org.apache.maven.plugins</groupId> | ||||
|                 <artifactId>maven-compiler-plugin</artifactId> | ||||
|                 <configuration> | ||||
|                     <source>${java.compile.version}</source> | ||||
|                     <target>${java.compile.version}</target> | ||||
|                 </configuration> | ||||
|             </plugin> | ||||
|         </plugins> | ||||
|     </build> | ||||
| 
 | ||||
|     <profiles> | ||||
|         <profile> | ||||
|             <id>test-nd4j-native</id> | ||||
|             <dependencies> | ||||
|                 <dependency> | ||||
|                     <groupId>org.nd4j</groupId> | ||||
|                     <artifactId>nd4j-native</artifactId> | ||||
|                     <version>${project.version}</version> | ||||
|                     <scope>test</scope> | ||||
|                 </dependency> | ||||
|             </dependencies> | ||||
|         </profile> | ||||
|         <profile> | ||||
|             <id>test-nd4j-cuda-11.0</id> | ||||
|             <dependencies> | ||||
|                 <dependency> | ||||
|                     <groupId>org.nd4j</groupId> | ||||
|                     <artifactId>nd4j-cuda-11.0</artifactId> | ||||
|                     <version>${project.version}</version> | ||||
|                     <scope>test</scope> | ||||
|                 </dependency> | ||||
|             </dependencies> | ||||
|         </profile> | ||||
|     </profiles> | ||||
| </project> | ||||
| @ -1,67 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.nearestneighbor.server; | ||||
| 
 | ||||
| import lombok.AllArgsConstructor; | ||||
| import lombok.Builder; | ||||
| import org.deeplearning4j.clustering.sptree.DataPoint; | ||||
| import org.deeplearning4j.clustering.vptree.VPTree; | ||||
| import org.deeplearning4j.nearestneighbor.model.NearestNeighborRequest; | ||||
| import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| @AllArgsConstructor | ||||
| @Builder | ||||
| public class NearestNeighbor { | ||||
|     private NearestNeighborRequest record; | ||||
|     private VPTree tree; | ||||
|     private INDArray points; | ||||
| 
 | ||||
|     public List<NearestNeighborsResult> search() { | ||||
|         INDArray input = points.slice(record.getInputIndex()); | ||||
|         List<NearestNeighborsResult> results = new ArrayList<>(); | ||||
|         if (input.isVector()) { | ||||
|             List<DataPoint> add = new ArrayList<>(); | ||||
|             List<Double> distances = new ArrayList<>(); | ||||
|             tree.search(input, record.getK(), add, distances); | ||||
| 
 | ||||
|             if (add.size() != distances.size()) { | ||||
|                 throw new IllegalStateException( | ||||
|                         String.format("add.size == %d != %d == distances.size", | ||||
|                                 add.size(), distances.size())); | ||||
|             } | ||||
| 
 | ||||
|             for (int i=0; i<add.size(); i++) { | ||||
|                 results.add(new NearestNeighborsResult(add.get(i).getIndex(), distances.get(i))); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         return results; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
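| 
 | ||||
| For reference, a minimal usage sketch of the class above (mirroring the also-deleted NearestNeighborTest further down, so the constructor and builder names come from this codebase; the data values are illustrative only): | ||||
| 
 | ||||
|     // Query the 2 nearest neighbors of row 0 of an existing points matrix | ||||
|     INDArray points = Nd4j.create(new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}}); | ||||
|     VPTree tree = new VPTree(points, false); | ||||
| 
 | ||||
|     NearestNeighborRequest request = new NearestNeighborRequest(); | ||||
|     request.setInputIndex(0);   // index of an EXISTING row to search around | ||||
|     request.setK(2); | ||||
| 
 | ||||
|     NearestNeighbor nn = NearestNeighbor.builder() | ||||
|                     .points(points).tree(tree).record(request).build(); | ||||
|     List<NearestNeighborsResult> results = nn.search();   // (index, distance) pairs | ||||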
| @ -1,278 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.nearestneighbor.server; | ||||
| 
 | ||||
| import com.beust.jcommander.JCommander; | ||||
| import com.beust.jcommander.Parameter; | ||||
| import com.beust.jcommander.ParameterException; | ||||
| import io.netty.handler.codec.http.HttpResponseStatus; | ||||
| import io.vertx.core.AbstractVerticle; | ||||
| import io.vertx.core.Vertx; | ||||
| import io.vertx.ext.web.Router; | ||||
| import io.vertx.ext.web.handler.BodyHandler; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.apache.commons.io.FileUtils; | ||||
| import org.deeplearning4j.clustering.sptree.DataPoint; | ||||
| import org.deeplearning4j.clustering.vptree.VPTree; | ||||
| import org.deeplearning4j.clustering.vptree.VPTreeFillSearch; | ||||
| import org.deeplearning4j.exception.DL4JInvalidInputException; | ||||
| import org.deeplearning4j.nearestneighbor.model.*; | ||||
| import org.deeplearning4j.nn.conf.serde.JsonMappers; | ||||
| import org.nd4j.linalg.api.buffer.DataBuffer; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.shape.Shape; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.linalg.indexing.NDArrayIndex; | ||||
| import org.nd4j.serde.base64.Nd4jBase64; | ||||
| import org.nd4j.serde.binary.BinarySerde; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.util.*; | ||||
| 
 | ||||
| @Slf4j | ||||
| public class NearestNeighborsServer extends AbstractVerticle { | ||||
| 
 | ||||
|     private static class RunArgs { | ||||
|         @Parameter(names = {"--ndarrayPath"}, arity = 1, required = true) | ||||
|         private String ndarrayPath = null; | ||||
|         @Parameter(names = {"--labelsPath"}, arity = 1, required = false) | ||||
|         private String labelsPath = null; | ||||
|         @Parameter(names = {"--nearestNeighborsPort"}, arity = 1) | ||||
|         private int port = 9000; | ||||
|         @Parameter(names = {"--similarityFunction"}, arity = 1) | ||||
|         private String similarityFunction = "euclidean"; | ||||
|         @Parameter(names = {"--invert"}, arity = 1) | ||||
|         private boolean invert = false; | ||||
|     } | ||||
| 
 | ||||
|     private static RunArgs instanceArgs; | ||||
|     private static NearestNeighborsServer instance; | ||||
| 
 | ||||
|     public NearestNeighborsServer(){ } | ||||
| 
 | ||||
|     public static NearestNeighborsServer getInstance(){ | ||||
|         return instance; | ||||
|     } | ||||
| 
 | ||||
|     public static void runMain(String... args) { | ||||
|         RunArgs r = new RunArgs(); | ||||
|         JCommander jcmdr = new JCommander(r); | ||||
| 
 | ||||
|         try { | ||||
|             jcmdr.parse(args); | ||||
|         } catch (ParameterException e) { | ||||
|             log.error("Error in NearestNeighboursServer parameters", e); | ||||
|             StringBuilder sb = new StringBuilder(); | ||||
|             jcmdr.usage(sb); | ||||
|             log.error("Usage: {}", sb.toString()); | ||||
| 
 | ||||
|             //User provides invalid input -> print the usage info | ||||
|             jcmdr.usage(); | ||||
|             if (r.ndarrayPath == null) | ||||
|                 log.error("Json path parameter is missing (null)"); | ||||
|             try { | ||||
|                 Thread.sleep(500); | ||||
|             } catch (Exception e2) { | ||||
|             } | ||||
|             System.exit(1); | ||||
|         } | ||||
| 
 | ||||
|         instanceArgs = r; | ||||
|         try { | ||||
|             Vertx vertx = Vertx.vertx(); | ||||
|             vertx.deployVerticle(NearestNeighborsServer.class.getName()); | ||||
|         } catch (Throwable t){ | ||||
|             log.error("Error in NearestNeighboursServer run method",t); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public void start() throws Exception { | ||||
|         instance = this; | ||||
| 
 | ||||
|         String[] pathArr = instanceArgs.ndarrayPath.split(","); | ||||
|         // First, read the shapes of the previously saved files to size the points matrix | ||||
|         int rows = 0; | ||||
|         int cols = 0; | ||||
|         for (int i = 0; i < pathArr.length; i++) { | ||||
|             DataBuffer shape = BinarySerde.readShapeFromDisk(new File(pathArr[i])); | ||||
| 
 | ||||
|             log.info("Loading shape {} of {}; Shape: [{} x {}]", i + 1, pathArr.length, Shape.size(shape, 0), | ||||
|                     Shape.size(shape, 1)); | ||||
| 
 | ||||
|             if (Shape.rank(shape) != 2) | ||||
|                 throw new DL4JInvalidInputException("NearestNeighborsServer assumes 2D chunks"); | ||||
| 
 | ||||
|             rows += Shape.size(shape, 0); | ||||
| 
 | ||||
|             if (cols == 0) | ||||
|                 cols = Shape.size(shape, 1); | ||||
|             else if (cols != Shape.size(shape, 1)) | ||||
|                 throw new DL4JInvalidInputException( | ||||
|                         "NearestNeighborsServer requires equal 2D chunks. Got columns mismatch."); | ||||
|         } | ||||
| 
 | ||||
|         final List<String> labels = new ArrayList<>(); | ||||
|         if (instanceArgs.labelsPath != null) { | ||||
|             String[] labelsPathArr = instanceArgs.labelsPath.split(","); | ||||
|             for (int i = 0; i < labelsPathArr.length; i++) { | ||||
|                 labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8")); | ||||
|             } | ||||
|         } | ||||
|         if (!labels.isEmpty() && labels.size() != rows) | ||||
|             throw new DL4JInvalidInputException(String.format("Number of labels must match number of rows in points matrix (expected %d, found %d)", rows, labels.size())); | ||||
| 
 | ||||
|         final INDArray points = Nd4j.createUninitialized(rows, cols); | ||||
| 
 | ||||
|         int lastPosition = 0; | ||||
|         for (int i = 0; i < pathArr.length; i++) { | ||||
|             log.info("Loading chunk {} of {}", i + 1, pathArr.length); | ||||
|             INDArray pointsArr = BinarySerde.readFromDisk(new File(pathArr[i])); | ||||
| 
 | ||||
|             points.get(NDArrayIndex.interval(lastPosition, lastPosition + pointsArr.rows())).assign(pointsArr); | ||||
|             lastPosition += pointsArr.rows(); | ||||
| 
 | ||||
|             // Encourage collection of the chunk we just copied before loading the next one | ||||
|             System.gc(); | ||||
|         } | ||||
| 
 | ||||
|         VPTree tree = new VPTree(points, instanceArgs.similarityFunction, instanceArgs.invert); | ||||
| 
 | ||||
|         //Set play secret key, if required | ||||
|         //http://www.playframework.com/documentation/latest/ApplicationSecret | ||||
|         String crypto = System.getProperty("play.crypto.secret"); | ||||
|         if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) { | ||||
|             byte[] newCrypto = new byte[1024]; | ||||
| 
 | ||||
|             new Random().nextBytes(newCrypto); | ||||
| 
 | ||||
|             String base64 = Base64.getEncoder().encodeToString(newCrypto); | ||||
|             System.setProperty("play.crypto.secret", base64); | ||||
|         } | ||||
| 
 | ||||
|         Router r = Router.router(vertx); | ||||
|         r.route().handler(BodyHandler.create());  //NOTE: Setting this is required to receive request body content at all | ||||
|         createRoutes(r, labels, tree, points); | ||||
| 
 | ||||
|         vertx.createHttpServer() | ||||
|                 .requestHandler(r) | ||||
|                 .listen(instanceArgs.port); | ||||
|     } | ||||
| 
 | ||||
|     private void createRoutes(Router r, List<String> labels, VPTree tree, INDArray points){ | ||||
| 
 | ||||
|         r.post("/knn").handler(rc -> { | ||||
|             try { | ||||
|                 String json = rc.getBodyAsJson().encode(); | ||||
|                 NearestNeighborRequest record = JsonMappers.getMapper().readValue(json, NearestNeighborRequest.class); | ||||
| 
 | ||||
|                 //Validate the parsed request before using it | ||||
|                 if (record == null) { | ||||
|                     rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()) | ||||
|                             .putHeader("content-type", "application/json") | ||||
|                             .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed."))); | ||||
|                     return; | ||||
|                 } | ||||
| 
 | ||||
|                 NearestNeighbor nearestNeighbor = | ||||
|                         NearestNeighbor.builder().points(points).record(record).tree(tree).build(); | ||||
| 
 | ||||
|                 NearestNeighborsResults results = NearestNeighborsResults.builder().results(nearestNeighbor.search()).build(); | ||||
| 
 | ||||
|                 rc.response().setStatusCode(HttpResponseStatus.OK.code()) | ||||
|                         .putHeader("content-type", "application/json") | ||||
|                         .end(JsonMappers.getMapper().writeValueAsString(results)); | ||||
|                 return; | ||||
|             } catch (Throwable e) { | ||||
|                 log.error("Error in POST /knn",e); | ||||
|                 rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) | ||||
|                         .end("Error parsing request - " + e.getMessage()); | ||||
|                 return; | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         r.post("/knnnew").handler(rc -> { | ||||
|             try { | ||||
|                 String json = rc.getBodyAsJson().encode(); | ||||
|                 Base64NDArrayBody record = JsonMappers.getMapper().readValue(json, Base64NDArrayBody.class); | ||||
|                 if (record == null) { | ||||
|                     rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()) | ||||
|                             .putHeader("content-type", "application/json") | ||||
|                             .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed."))); | ||||
|                     return; | ||||
|                 } | ||||
| 
 | ||||
|                 INDArray arr = Nd4jBase64.fromBase64(record.getNdarray()); | ||||
|                 List<DataPoint> results; | ||||
|                 List<Double> distances; | ||||
| 
 | ||||
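|                 // forceFillK requests a fill search that tries to return exactly k results; otherwise a plain VPTree search is used | ||||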
|                 if (record.isForceFillK()) { | ||||
|                     VPTreeFillSearch vpTreeFillSearch = new VPTreeFillSearch(tree, record.getK(), arr); | ||||
|                     vpTreeFillSearch.search(); | ||||
|                     results = vpTreeFillSearch.getResults(); | ||||
|                     distances = vpTreeFillSearch.getDistances(); | ||||
|                 } else { | ||||
|                     results = new ArrayList<>(); | ||||
|                     distances = new ArrayList<>(); | ||||
|                     tree.search(arr, record.getK(), results, distances); | ||||
|                 } | ||||
| 
 | ||||
|                 if (results.size() != distances.size()) { | ||||
|                     rc.response() | ||||
|                             .setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) | ||||
|                             .end(String.format("results.size == %d != %d == distances.size", results.size(), distances.size())); | ||||
|                     return; | ||||
|                 } | ||||
| 
 | ||||
|                 List<NearestNeighborsResult> nnResult = new ArrayList<>(); | ||||
|                 for (int i=0; i<results.size(); i++) { | ||||
|                     if (!labels.isEmpty()) | ||||
|                         nnResult.add(new NearestNeighborsResult(results.get(i).getIndex(), distances.get(i), labels.get(results.get(i).getIndex()))); | ||||
|                     else | ||||
|                         nnResult.add(new NearestNeighborsResult(results.get(i).getIndex(), distances.get(i))); | ||||
|                 } | ||||
| 
 | ||||
|                 NearestNeighborsResults results2 = NearestNeighborsResults.builder().results(nnResult).build(); | ||||
|                 String j = JsonMappers.getMapper().writeValueAsString(results2); | ||||
|                 rc.response() | ||||
|                         .putHeader("content-type", "application/json") | ||||
|                         .end(j); | ||||
|             } catch (Throwable e) { | ||||
|                 log.error("Error in POST /knnnew",e); | ||||
|                 rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) | ||||
|                         .end("Error parsing request - " + e.getMessage()); | ||||
|                 return; | ||||
|             } | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Stop the server | ||||
|      */ | ||||
|     public void stop() throws Exception { | ||||
|         super.stop(); | ||||
|     } | ||||
| 
 | ||||
|     public static void main(String[] args) throws Exception { | ||||
|         runMain(args); | ||||
|     } | ||||
| 
 | ||||
| } | ||||
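| 
 | ||||
| For reference, a minimal launch sketch following NearestNeighborTest.testServer below (the file name and port are illustrative; the points file must hold a 2D array written with BinarySerde): | ||||
| 
 | ||||
|     // Persist the points, then start the verticle on port 9000 | ||||
|     INDArray points = Nd4j.randn(10, 5); | ||||
|     File f = new File("points.bin"); | ||||
|     BinarySerde.writeArrayToDisk(points, f); | ||||
|     NearestNeighborsServer.runMain("--ndarrayPath", f.getAbsolutePath(), | ||||
|                     "--nearestNeighborsPort", "9000"); | ||||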
| @ -1,161 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.nearestneighbor.server; | ||||
| 
 | ||||
| import org.deeplearning4j.BaseDL4JTest; | ||||
| import org.deeplearning4j.clustering.sptree.DataPoint; | ||||
| import org.deeplearning4j.clustering.vptree.VPTree; | ||||
| import org.deeplearning4j.clustering.vptree.VPTreeFillSearch; | ||||
| import org.deeplearning4j.nearestneighbor.client.NearestNeighborsClient; | ||||
| import org.deeplearning4j.nearestneighbor.model.NearestNeighborRequest; | ||||
| import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult; | ||||
| import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResults; | ||||
| import org.junit.Rule; | ||||
| import org.junit.Test; | ||||
| import org.junit.rules.TemporaryFolder; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.serde.binary.BinarySerde; | ||||
| 
 | ||||
| import java.io.File; | ||||
| import java.io.IOException; | ||||
| import java.net.ServerSocket; | ||||
| import java.util.List; | ||||
| import java.util.concurrent.Executor; | ||||
| import java.util.concurrent.Executors; | ||||
| 
 | ||||
| import static org.junit.Assert.assertEquals; | ||||
| 
 | ||||
| public class NearestNeighborTest extends BaseDL4JTest { | ||||
| 
 | ||||
|     @Rule | ||||
|     public TemporaryFolder testDir = new TemporaryFolder(); | ||||
| 
 | ||||
|     @Test | ||||
|     public void testNearestNeighbor() { | ||||
|         double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}}; | ||||
|         INDArray arr = Nd4j.create(data); | ||||
| 
 | ||||
|         VPTree vpTree = new VPTree(arr, false); | ||||
|         NearestNeighborRequest request = new NearestNeighborRequest(); | ||||
|         request.setK(2); | ||||
|         request.setInputIndex(0); | ||||
|         NearestNeighbor nearestNeighbor = NearestNeighbor.builder().tree(vpTree).points(arr).record(request).build(); | ||||
|         List<NearestNeighborsResult> results = nearestNeighbor.search(); | ||||
|         assertEquals(1, results.get(0).getIndex()); | ||||
|         assertEquals(2, results.size()); | ||||
| 
 | ||||
|         assertEquals(1.0, results.get(0).getDistance(), 1e-4); | ||||
|         assertEquals(4.0, results.get(1).getDistance(), 1e-4); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testNearestNeighborInverted() { | ||||
|         double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}}; | ||||
|         INDArray arr = Nd4j.create(data); | ||||
| 
 | ||||
|         VPTree vpTree = new VPTree(arr, true); | ||||
|         NearestNeighborRequest request = new NearestNeighborRequest(); | ||||
|         request.setK(2); | ||||
|         request.setInputIndex(0); | ||||
|         NearestNeighbor nearestNeighbor = NearestNeighbor.builder().tree(vpTree).points(arr).record(request).build(); | ||||
|         List<NearestNeighborsResult> results = nearestNeighbor.search(); | ||||
|         assertEquals(2, results.get(0).getIndex()); | ||||
|         assertEquals(2, results.size()); | ||||
| 
 | ||||
|         assertEquals(-4.0, results.get(0).getDistance(), 1e-4); | ||||
|         assertEquals(-1.0, results.get(1).getDistance(), 1e-4); | ||||
|     } | ||||
| 
 | ||||
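|     // Smoke test: constructing a VPTree over a random 400 x 10 matrix should complete without error | ||||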
|     @Test | ||||
|     public void vpTreeTest() throws Exception { | ||||
|         INDArray matrix = Nd4j.rand(new int[] {400,10}); | ||||
|         INDArray rowVector = matrix.getRow(70); | ||||
|         INDArray resultArr = Nd4j.zeros(400,1); | ||||
|         Executor executor = Executors.newSingleThreadExecutor(); | ||||
|         VPTree vpTree = new VPTree(matrix); | ||||
|         System.out.println("Ran!"); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
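|     // Find a free ephemeral port by briefly binding a throwaway ServerSocket | ||||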
|     public static int getAvailablePort() { | ||||
|         try { | ||||
|             ServerSocket socket = new ServerSocket(0); | ||||
|             try { | ||||
|                 return socket.getLocalPort(); | ||||
|             } finally { | ||||
|                 socket.close(); | ||||
|             } | ||||
|         } catch (IOException e) { | ||||
|             throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testServer() throws Exception { | ||||
|         int localPort = getAvailablePort(); | ||||
|         Nd4j.getRandom().setSeed(7); | ||||
|         INDArray rand = Nd4j.randn(10, 5); | ||||
|         File writeToTmp = testDir.newFile(); | ||||
|         writeToTmp.deleteOnExit(); | ||||
|         BinarySerde.writeArrayToDisk(rand, writeToTmp); | ||||
|         NearestNeighborsServer.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort", | ||||
|                 String.valueOf(localPort)); | ||||
| 
 | ||||
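|         // Give the verticle time to deploy and bind its port before connecting | ||||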
|         Thread.sleep(3000); | ||||
| 
 | ||||
|         NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort); | ||||
|         NearestNeighborsResults result = client.knnNew(5, rand.getRow(0)); | ||||
|         assertEquals(5, result.getResults().size()); | ||||
|         NearestNeighborsServer.getInstance().stop(); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testFullSearch() throws Exception { | ||||
|         int numRows = 1000; | ||||
|         int numCols = 100; | ||||
|         int numNeighbors = 42; | ||||
|         INDArray points = Nd4j.rand(numRows, numCols); | ||||
|         VPTree tree = new VPTree(points); | ||||
|         INDArray query = Nd4j.rand(new int[] {1, numCols}); | ||||
|         VPTreeFillSearch fillSearch = new VPTreeFillSearch(tree, numNeighbors, query); | ||||
|         fillSearch.search(); | ||||
|         List<DataPoint> results = fillSearch.getResults(); | ||||
|         List<Double> distances = fillSearch.getDistances(); | ||||
|         assertEquals(numNeighbors, distances.size()); | ||||
|         assertEquals(numNeighbors, results.size()); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testDistances() { | ||||
| 
 | ||||
|         INDArray indArray = Nd4j.create(new float[][]{{3, 4}, {1, 2}, {5, 6}}); | ||||
|         INDArray record = Nd4j.create(new float[][]{{7, 6}}); | ||||
|         VPTree vpTree = new VPTree(indArray, "euclidean", false); | ||||
|         VPTreeFillSearch vpTreeFillSearch = new VPTreeFillSearch(vpTree, 3, record); | ||||
|         vpTreeFillSearch.search(); | ||||
|         //System.out.println(vpTreeFillSearch.getResults()); | ||||
|         System.out.println(vpTreeFillSearch.getDistances()); | ||||
|     } | ||||
| } | ||||
| @ -1,46 +0,0 @@ | ||||
| <!-- | ||||
|   ~ /* ****************************************************************************** | ||||
|   ~  * | ||||
|   ~  * | ||||
|   ~  * This program and the accompanying materials are made available under the | ||||
|   ~  * terms of the Apache License, Version 2.0 which is available at | ||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|   ~  * | ||||
|   ~  *  See the NOTICE file distributed with this work for additional | ||||
|   ~  *  information regarding copyright ownership. | ||||
|   ~  * Unless required by applicable law or agreed to in writing, software | ||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|   ~  * License for the specific language governing permissions and limitations | ||||
|   ~  * under the License. | ||||
|   ~  * | ||||
|   ~  * SPDX-License-Identifier: Apache-2.0 | ||||
|   ~  ******************************************************************************/ | ||||
|   --> | ||||
| 
 | ||||
| <configuration> | ||||
| 
 | ||||
|     <appender name="FILE" class="ch.qos.logback.core.FileAppender"> | ||||
|         <file>logs/application.log</file> | ||||
|         <encoder> | ||||
|             <pattern> %logger{15} - %message%n%xException{5} | ||||
|             </pattern> | ||||
|         </encoder> | ||||
|     </appender> | ||||
| 
 | ||||
|     <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender"> | ||||
|         <encoder> | ||||
|             <pattern> %logger{15} - %message%n%xException{5} | ||||
|             </pattern> | ||||
|         </encoder> | ||||
|     </appender> | ||||
| 
 | ||||
|     <logger name="org.deeplearning4j" level="INFO" /> | ||||
|     <logger name="org.datavec" level="INFO" /> | ||||
|     <logger name="org.nd4j" level="INFO" /> | ||||
| 
 | ||||
|     <root level="ERROR"> | ||||
|         <appender-ref ref="STDOUT" /> | ||||
|         <appender-ref ref="FILE" /> | ||||
|     </root> | ||||
| </configuration> | ||||
| @ -1,60 +0,0 @@ | ||||
| <?xml version="1.0" encoding="UTF-8"?> | ||||
| <!-- | ||||
|   ~ /* ****************************************************************************** | ||||
|   ~  * | ||||
|   ~  * | ||||
|   ~  * This program and the accompanying materials are made available under the | ||||
|   ~  * terms of the Apache License, Version 2.0 which is available at | ||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|   ~  * | ||||
|   ~  *  See the NOTICE file distributed with this work for additional | ||||
|   ~  *  information regarding copyright ownership. | ||||
|   ~  * Unless required by applicable law or agreed to in writing, software | ||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|   ~  * License for the specific language governing permissions and limitations | ||||
|   ~  * under the License. | ||||
|   ~  * | ||||
|   ~  * SPDX-License-Identifier: Apache-2.0 | ||||
|   ~  ******************************************************************************/ | ||||
|   --> | ||||
| 
 | ||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" | ||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||||
| 
 | ||||
|     <modelVersion>4.0.0</modelVersion> | ||||
| 
 | ||||
|     <parent> | ||||
|         <groupId>org.deeplearning4j</groupId> | ||||
|         <artifactId>deeplearning4j-nearestneighbors-parent</artifactId> | ||||
|         <version>1.0.0-SNAPSHOT</version> | ||||
|     </parent> | ||||
| 
 | ||||
|     <artifactId>deeplearning4j-nearestneighbors-client</artifactId> | ||||
|     <packaging>jar</packaging> | ||||
| 
 | ||||
|     <name>deeplearning4j-nearestneighbors-client</name> | ||||
| 
 | ||||
|     <dependencies> | ||||
|         <dependency> | ||||
|             <groupId>com.mashape.unirest</groupId> | ||||
|             <artifactId>unirest-java</artifactId> | ||||
|             <version>${unirest.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.deeplearning4j</groupId> | ||||
|             <artifactId>deeplearning4j-nearestneighbors-model</artifactId> | ||||
|             <version>${project.version}</version> | ||||
|         </dependency> | ||||
|     </dependencies> | ||||
| 
 | ||||
|     <profiles> | ||||
|         <profile> | ||||
|             <id>test-nd4j-native</id> | ||||
|         </profile> | ||||
|         <profile> | ||||
|             <id>test-nd4j-cuda-11.0</id> | ||||
|         </profile> | ||||
|     </profiles> | ||||
| </project> | ||||
| @ -1,137 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.nearestneighbor.client; | ||||
| 
 | ||||
| import com.mashape.unirest.http.ObjectMapper; | ||||
| import com.mashape.unirest.http.Unirest; | ||||
| import com.mashape.unirest.request.HttpRequest; | ||||
| import lombok.AllArgsConstructor; | ||||
| import lombok.Getter; | ||||
| import lombok.Setter; | ||||
| import lombok.val; | ||||
| import org.deeplearning4j.nearestneighbor.model.*; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.serde.base64.Nd4jBase64; | ||||
| import org.nd4j.shade.jackson.core.JsonProcessingException; | ||||
| 
 | ||||
| import java.io.IOException; | ||||
| 
 | ||||
| @AllArgsConstructor | ||||
| public class NearestNeighborsClient { | ||||
| 
 | ||||
|     private String url; | ||||
|     @Setter | ||||
|     @Getter | ||||
|     protected String authToken; | ||||
| 
 | ||||
|     public NearestNeighborsClient(String url){ | ||||
|         this(url, null); | ||||
|     } | ||||
| 
 | ||||
|     static { | ||||
|         // Configure Unirest's shared object mapper once, when this class is first loaded | ||||
| 
 | ||||
|         Unirest.setObjectMapper(new ObjectMapper() { | ||||
|             private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = | ||||
|                             new org.nd4j.shade.jackson.databind.ObjectMapper(); | ||||
| 
 | ||||
|             public <T> T readValue(String value, Class<T> valueType) { | ||||
|                 try { | ||||
|                     return jacksonObjectMapper.readValue(value, valueType); | ||||
|                 } catch (IOException e) { | ||||
|                     throw new RuntimeException(e); | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             public String writeValue(Object value) { | ||||
|                 try { | ||||
|                     return jacksonObjectMapper.writeValueAsString(value); | ||||
|                 } catch (JsonProcessingException e) { | ||||
|                     throw new RuntimeException(e); | ||||
|                 } | ||||
|             } | ||||
|         }); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Runs a k-nearest-neighbors search for the ndarray at the given index. | ||||
|      * Note that this queries a point already within the existing dataset, | ||||
|      * not new data. | ||||
|      * @param index the index of the EXISTING ndarray to run a search on | ||||
|      * @param k the number of results to retrieve | ||||
|      * @return the nearest neighbors found for the given index | ||||
|      * @throws Exception if the request or response deserialization fails | ||||
|      */ | ||||
|     public NearestNeighborsResults knn(int index, int k) throws Exception { | ||||
|         NearestNeighborRequest request = new NearestNeighborRequest(); | ||||
|         request.setInputIndex(index); | ||||
|         request.setK(k); | ||||
|         val req = Unirest.post(url + "/knn"); | ||||
|         req.header("accept", "application/json") | ||||
|                 .header("Content-Type", "application/json").body(request); | ||||
|         addAuthHeader(req); | ||||
| 
 | ||||
|         NearestNeighborsResults ret = req.asObject(NearestNeighborsResults.class).getBody(); | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Run a k-nearest-neighbors search on a NEW data point | ||||
|      * @param k the number of results to retrieve | ||||
|      * @param arr the array to run the search on; must be a row vector | ||||
|      * @return the nearest neighbors found for the given array | ||||
|      * @throws Exception if the request or response deserialization fails | ||||
|      */ | ||||
|     public NearestNeighborsResults knnNew(int k, INDArray arr) throws Exception { | ||||
|         Base64NDArrayBody base64NDArrayBody = | ||||
|                         Base64NDArrayBody.builder().k(k).ndarray(Nd4jBase64.base64String(arr)).build(); | ||||
| 
 | ||||
|         val req = Unirest.post(url + "/knnnew"); | ||||
|         req.header("accept", "application/json") | ||||
|                 .header("Content-Type", "application/json").body(base64NDArrayBody); | ||||
|         addAuthHeader(req); | ||||
| 
 | ||||
|         NearestNeighborsResults ret = req.asObject(NearestNeighborsResults.class).getBody(); | ||||
| 
 | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Add the configured authentication header to the specified HttpRequest, if an auth token is set | ||||
|      * | ||||
|      * @param request HTTP Request to add the authentication header to | ||||
|      * @return the same request, for chaining | ||||
|      */ | ||||
|     protected HttpRequest addAuthHeader(HttpRequest request) { | ||||
|         if (authToken != null) { | ||||
|             request.header("authorization", "Bearer " + authToken); | ||||
|         } | ||||
| 
 | ||||
|         return request; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
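| 
 | ||||
| A minimal usage sketch of this client (not part of the original sources): it assumes a nearest-neighbors server is reachable at the hypothetical URL below, and the index/k values are purely illustrative. | ||||
| 
 | ||||
|     // Hypothetical usage sketch; the server URL and query values are assumptions | ||||
|     import org.deeplearning4j.nearestneighbor.client.NearestNeighborsClient; | ||||
|     import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResults; | ||||
|     import org.nd4j.linalg.api.ndarray.INDArray; | ||||
|     import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
|     public class KnnClientExample { | ||||
|         public static void main(String[] args) throws Exception { | ||||
|             NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:9000"); | ||||
|             // Query a point that is already part of the server-side dataset | ||||
|             NearestNeighborsResults byIndex = client.knn(0, 5); | ||||
|             // Query a brand-new point; the array must be a row vector | ||||
|             INDArray query = Nd4j.create(new double[]{1.0, 2.0, 3.0}, new long[]{1, 3}); | ||||
|             NearestNeighborsResults byPoint = client.knnNew(5, query); | ||||
|             System.out.println(byIndex.getResults()); | ||||
|             System.out.println(byPoint.getResults()); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||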
| @ -1,61 +0,0 @@ | ||||
| <?xml version="1.0" encoding="UTF-8"?> | ||||
| <!-- | ||||
|   ~ /* ****************************************************************************** | ||||
|   ~  * | ||||
|   ~  * | ||||
|   ~  * This program and the accompanying materials are made available under the | ||||
|   ~  * terms of the Apache License, Version 2.0 which is available at | ||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|   ~  * | ||||
|   ~  *  See the NOTICE file distributed with this work for additional | ||||
|   ~  *  information regarding copyright ownership. | ||||
|   ~  * Unless required by applicable law or agreed to in writing, software | ||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|   ~  * License for the specific language governing permissions and limitations | ||||
|   ~  * under the License. | ||||
|   ~  * | ||||
|   ~  * SPDX-License-Identifier: Apache-2.0 | ||||
|   ~  ******************************************************************************/ | ||||
|   --> | ||||
| 
 | ||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" | ||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||||
| 
 | ||||
|     <modelVersion>4.0.0</modelVersion> | ||||
| 
 | ||||
|     <parent> | ||||
|         <groupId>org.deeplearning4j</groupId> | ||||
|         <artifactId>deeplearning4j-nearestneighbors-parent</artifactId> | ||||
|         <version>1.0.0-SNAPSHOT</version> | ||||
|     </parent> | ||||
| 
 | ||||
|     <artifactId>deeplearning4j-nearestneighbors-model</artifactId> | ||||
|     <packaging>jar</packaging> | ||||
| 
 | ||||
|     <name>deeplearning4j-nearestneighbors-model</name> | ||||
| 
 | ||||
|     <dependencies> | ||||
|         <dependency> | ||||
|             <groupId>org.projectlombok</groupId> | ||||
|             <artifactId>lombok</artifactId> | ||||
|             <version>${lombok.version}</version> | ||||
|             <scope>provided</scope> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.nd4j</groupId> | ||||
|             <artifactId>nd4j-api</artifactId> | ||||
|             <version>${nd4j.version}</version> | ||||
|         </dependency> | ||||
|     </dependencies> | ||||
| 
 | ||||
|     <profiles> | ||||
|         <profile> | ||||
|             <id>test-nd4j-native</id> | ||||
|         </profile> | ||||
|         <profile> | ||||
|             <id>test-nd4j-cuda-11.0</id> | ||||
|         </profile> | ||||
|     </profiles> | ||||
| </project> | ||||
| @ -1,38 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.nearestneighbor.model; | ||||
| 
 | ||||
| import lombok.AllArgsConstructor; | ||||
| import lombok.Builder; | ||||
| import lombok.Data; | ||||
| import lombok.NoArgsConstructor; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| 
 | ||||
| @Data | ||||
| @AllArgsConstructor | ||||
| @NoArgsConstructor | ||||
| @Builder | ||||
| public class Base64NDArrayBody implements Serializable { | ||||
|     private String ndarray; | ||||
|     private int k; | ||||
|     private boolean forceFillK; | ||||
| } | ||||
| @ -1,65 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.nearestneighbor.model; | ||||
| 
 | ||||
| import lombok.AllArgsConstructor; | ||||
| import lombok.Builder; | ||||
| import lombok.Data; | ||||
| import lombok.NoArgsConstructor; | ||||
| import org.nd4j.linalg.dataset.DataSet; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| @Data | ||||
| @AllArgsConstructor | ||||
| @Builder | ||||
| @NoArgsConstructor | ||||
| public class BatchRecord implements Serializable { | ||||
|     private List<CSVRecord> records; | ||||
| 
 | ||||
|     /** | ||||
|      * Add a record to the batch, initializing the backing list if needed | ||||
|      * @param record the record to add | ||||
|      */ | ||||
|     public void add(CSVRecord record) { | ||||
|         if (records == null) | ||||
|             records = new ArrayList<>(); | ||||
|         records.add(record); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Return a batch record based on a dataset | ||||
|      * @param dataSet the dataset to get the batch record for | ||||
|      * @return the batch record | ||||
|      */ | ||||
|     public static BatchRecord fromDataSet(DataSet dataSet) { | ||||
|         BatchRecord batchRecord = new BatchRecord(); | ||||
|         for (int i = 0; i < dataSet.numExamples(); i++) { | ||||
|             batchRecord.add(CSVRecord.fromRow(dataSet.get(i))); | ||||
|         } | ||||
| 
 | ||||
|         return batchRecord; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
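| 
 | ||||
| To make fromDataSet concrete, a small sketch follows; the inline two-example DataSet is made up for illustration. | ||||
| 
 | ||||
|     // Illustrative sketch: converts a tiny, made-up DataSet into a BatchRecord | ||||
|     import org.deeplearning4j.nearestneighbor.model.BatchRecord; | ||||
|     import org.nd4j.linalg.dataset.DataSet; | ||||
|     import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
|     public class BatchRecordExample { | ||||
|         public static void main(String[] args) { | ||||
|             // Two examples with two features each and one-hot labels over two classes | ||||
|             DataSet ds = new DataSet( | ||||
|                     Nd4j.create(new double[][]{{0.1, 0.2}, {0.3, 0.4}}), | ||||
|                     Nd4j.create(new double[][]{{1, 0}, {0, 1}})); | ||||
|             BatchRecord batch = BatchRecord.fromDataSet(ds); | ||||
|             System.out.println(batch.getRecords().size()); // prints 2 | ||||
|         } | ||||
|     } | ||||
| 
 | ||||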
| @ -1,85 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.nearestneighbor.model; | ||||
| 
 | ||||
| import lombok.AllArgsConstructor; | ||||
| import lombok.Data; | ||||
| import lombok.NoArgsConstructor; | ||||
| import org.nd4j.linalg.dataset.DataSet; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| 
 | ||||
| @Data | ||||
| @AllArgsConstructor | ||||
| @NoArgsConstructor | ||||
| public class CSVRecord implements Serializable { | ||||
|     private String[] values; | ||||
| 
 | ||||
|     /** | ||||
|      * Instantiate a csv record from a single-example {@link DataSet}. | ||||
|      * For classification (one-hot labels), the index of the maximum label | ||||
|      * value is appended to the end of the record; for regression, all | ||||
|      * label values are appended instead. | ||||
|      * @param row the input dataset row (features and labels must be scalars or vectors) | ||||
|      * @return the record derived from this {@link DataSet} | ||||
|      */ | ||||
|     public static CSVRecord fromRow(DataSet row) { | ||||
|         if (!row.getFeatures().isVector() && !row.getFeatures().isScalar()) | ||||
|             throw new IllegalArgumentException("Passed in dataset must represent a scalar or vector"); | ||||
|         if (!row.getLabels().isVector() && !row.getLabels().isScalar()) | ||||
|             throw new IllegalArgumentException("Passed in dataset labels must be a scalar or vector"); | ||||
|         //classification | ||||
|         CSVRecord record; | ||||
|         int idx = 0; | ||||
|         if (row.getLabels().sumNumber().doubleValue() == 1.0) { | ||||
|             String[] values = new String[row.getFeatures().columns() + 1]; | ||||
|             for (int i = 0; i < row.getFeatures().length(); i++) { | ||||
|                 values[idx++] = String.valueOf(row.getFeatures().getDouble(i)); | ||||
|             } | ||||
|             int maxIdx = 0; | ||||
|             for (int i = 0; i < row.getLabels().length(); i++) { | ||||
|                 if (row.getLabels().getDouble(maxIdx) < row.getLabels().getDouble(i)) { | ||||
|                     maxIdx = i; | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             values[idx++] = String.valueOf(maxIdx); | ||||
|             record = new CSVRecord(values); | ||||
|         } | ||||
|         //regression (any number of values) | ||||
|         else { | ||||
|             String[] values = new String[row.getFeatures().columns() + row.getLabels().columns()]; | ||||
|             for (int i = 0; i < row.getFeatures().length(); i++) { | ||||
|                 values[idx++] = String.valueOf(row.getFeatures().getDouble(i)); | ||||
|             } | ||||
|             for (int i = 0; i < row.getLabels().length(); i++) { | ||||
|                 values[idx++] = String.valueOf(row.getLabels().getDouble(i)); | ||||
|             } | ||||
| 
 | ||||
| 
 | ||||
|             record = new CSVRecord(values); | ||||
| 
 | ||||
|         } | ||||
|         return record; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
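| 
 | ||||
| To make the two branches concrete, a small illustrative sketch (all values are made up): a one-hot label row becomes features plus the argmax index, while a regression row keeps every label value. | ||||
| 
 | ||||
|     // Illustrative sketch of the classification vs. regression branches | ||||
|     import org.deeplearning4j.nearestneighbor.model.CSVRecord; | ||||
|     import org.nd4j.linalg.dataset.DataSet; | ||||
|     import org.nd4j.linalg.factory.Nd4j; | ||||
|     import java.util.Arrays; | ||||
| 
 | ||||
|     public class CSVRecordExample { | ||||
|         public static void main(String[] args) { | ||||
|             // Classification: one-hot labels -> features + argmax index | ||||
|             DataSet clf = new DataSet(Nd4j.create(new double[]{0.5, 1.5}, new long[]{1, 2}), | ||||
|                                       Nd4j.create(new double[]{0, 1, 0}, new long[]{1, 3})); | ||||
|             System.out.println(Arrays.toString(CSVRecord.fromRow(clf).getValues())); // [0.5, 1.5, 1] | ||||
|             // Regression: labels that do not sum to 1.0 -> features + all label values | ||||
|             DataSet reg = new DataSet(Nd4j.create(new double[]{0.5, 1.5}, new long[]{1, 2}), | ||||
|                                       Nd4j.create(new double[]{2.0, 3.0}, new long[]{1, 2})); | ||||
|             System.out.println(Arrays.toString(CSVRecord.fromRow(reg).getValues())); // [0.5, 1.5, 2.0, 3.0] | ||||
|         } | ||||
|     } | ||||
| 
 | ||||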
| @ -1,32 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.nearestneighbor.model; | ||||
| 
 | ||||
| import lombok.Data; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| 
 | ||||
| @Data | ||||
| public class NearestNeighborRequest implements Serializable { | ||||
|     private int k; | ||||
|     private int inputIndex; | ||||
| 
 | ||||
| } | ||||
| @ -1,37 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.nearestneighbor.model; | ||||
| 
 | ||||
| import lombok.AllArgsConstructor; | ||||
| import lombok.Data; | ||||
| import lombok.NoArgsConstructor; | ||||
| 
 | ||||
| @Data | ||||
| @AllArgsConstructor | ||||
| @NoArgsConstructor | ||||
| public class NearestNeighborsResult { | ||||
|     private int index; | ||||
|     private double distance; | ||||
|     private String label; | ||||
| 
 | ||||
|     public NearestNeighborsResult(int index, double distance) { | ||||
|         this(index, distance, null); | ||||
|     } | ||||
| } | ||||
| @ -1,38 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.nearestneighbor.model; | ||||
| 
 | ||||
| import lombok.AllArgsConstructor; | ||||
| import lombok.Builder; | ||||
| import lombok.Data; | ||||
| import lombok.NoArgsConstructor; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| import java.util.List; | ||||
| 
 | ||||
| @Data | ||||
| @Builder | ||||
| @NoArgsConstructor | ||||
| @AllArgsConstructor | ||||
| public class NearestNeighborsResults implements Serializable { | ||||
|     private List<NearestNeighborsResult> results; | ||||
| 
 | ||||
| } | ||||
| @ -1,103 +0,0 @@ | ||||
| <?xml version="1.0" encoding="UTF-8"?> | ||||
| <!-- | ||||
|   ~ /* ****************************************************************************** | ||||
|   ~  * | ||||
|   ~  * | ||||
|   ~  * This program and the accompanying materials are made available under the | ||||
|   ~  * terms of the Apache License, Version 2.0 which is available at | ||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|   ~  * | ||||
|   ~  *  See the NOTICE file distributed with this work for additional | ||||
|   ~  *  information regarding copyright ownership. | ||||
|   ~  * Unless required by applicable law or agreed to in writing, software | ||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|   ~  * License for the specific language governing permissions and limitations | ||||
|   ~  * under the License. | ||||
|   ~  * | ||||
|   ~  * SPDX-License-Identifier: Apache-2.0 | ||||
|   ~  ******************************************************************************/ | ||||
|   --> | ||||
| 
 | ||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" | ||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" | ||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> | ||||
| 
 | ||||
|     <modelVersion>4.0.0</modelVersion> | ||||
| 
 | ||||
|     <parent> | ||||
|         <groupId>org.deeplearning4j</groupId> | ||||
|         <artifactId>deeplearning4j-nearestneighbors-parent</artifactId> | ||||
|         <version>1.0.0-SNAPSHOT</version> | ||||
|     </parent> | ||||
| 
 | ||||
|     <artifactId>nearestneighbor-core</artifactId> | ||||
|     <packaging>jar</packaging> | ||||
| 
 | ||||
|     <name>nearestneighbor-core</name> | ||||
| 
 | ||||
|     <dependencies> | ||||
|         <dependency> | ||||
|             <groupId>org.nd4j</groupId> | ||||
|             <artifactId>nd4j-api</artifactId> | ||||
|             <version>${nd4j.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>junit</groupId> | ||||
|             <artifactId>junit</artifactId> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>ch.qos.logback</groupId> | ||||
|             <artifactId>logback-classic</artifactId> | ||||
|             <scope>test</scope> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.deeplearning4j</groupId> | ||||
|             <artifactId>deeplearning4j-nn</artifactId> | ||||
|             <version>${project.version}</version> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.deeplearning4j</groupId> | ||||
|             <artifactId>deeplearning4j-datasets</artifactId> | ||||
|             <version>${project.version}</version> | ||||
|             <scope>test</scope> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>joda-time</groupId> | ||||
|             <artifactId>joda-time</artifactId> | ||||
|             <version>2.10.3</version> | ||||
|             <scope>test</scope> | ||||
|         </dependency> | ||||
|         <dependency> | ||||
|             <groupId>org.deeplearning4j</groupId> | ||||
|             <artifactId>deeplearning4j-common-tests</artifactId> | ||||
|             <version>${project.version}</version> | ||||
|             <scope>test</scope> | ||||
|         </dependency> | ||||
|     </dependencies> | ||||
| 
 | ||||
|     <profiles> | ||||
|         <profile> | ||||
|             <id>test-nd4j-native</id> | ||||
|             <dependencies> | ||||
|                 <dependency> | ||||
|                     <groupId>org.nd4j</groupId> | ||||
|                     <artifactId>nd4j-native</artifactId> | ||||
|                     <version>${project.version}</version> | ||||
|                     <scope>test</scope> | ||||
|                 </dependency> | ||||
|             </dependencies> | ||||
|         </profile> | ||||
|         <profile> | ||||
|             <id>test-nd4j-cuda-11.0</id> | ||||
|             <dependencies> | ||||
|                 <dependency> | ||||
|                     <groupId>org.nd4j</groupId> | ||||
|                     <artifactId>nd4j-cuda-11.0</artifactId> | ||||
|                     <version>${project.version}</version> | ||||
|                     <scope>test</scope> | ||||
|                 </dependency> | ||||
|             </dependencies> | ||||
|         </profile> | ||||
|     </profiles> | ||||
| </project> | ||||
| @ -1,218 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.algorithm; | ||||
| 
 | ||||
| import lombok.AccessLevel; | ||||
| import lombok.NoArgsConstructor; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import lombok.val; | ||||
| import org.apache.commons.lang3.ArrayUtils; | ||||
| import org.deeplearning4j.clustering.cluster.Cluster; | ||||
| import org.deeplearning4j.clustering.cluster.ClusterSet; | ||||
| import org.deeplearning4j.clustering.cluster.ClusterUtils; | ||||
| import org.deeplearning4j.clustering.cluster.Point; | ||||
| import org.deeplearning4j.clustering.info.ClusterSetInfo; | ||||
| import org.deeplearning4j.clustering.iteration.IterationHistory; | ||||
| import org.deeplearning4j.clustering.iteration.IterationInfo; | ||||
| import org.deeplearning4j.clustering.strategy.ClusteringStrategy; | ||||
| import org.deeplearning4j.clustering.strategy.ClusteringStrategyType; | ||||
| import org.deeplearning4j.clustering.strategy.OptimisationStrategy; | ||||
| import org.deeplearning4j.clustering.util.MultiThreadUtils; | ||||
| import org.nd4j.common.base.Preconditions; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| import java.util.concurrent.ExecutorService; | ||||
| 
 | ||||
| @Slf4j | ||||
| @NoArgsConstructor(access = AccessLevel.PROTECTED) | ||||
| public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializable { | ||||
| 
 | ||||
|     private static final long serialVersionUID = 338231277453149972L; | ||||
| 
 | ||||
|     private ClusteringStrategy clusteringStrategy; | ||||
|     private IterationHistory iterationHistory; | ||||
|     private int currentIteration = 0; | ||||
|     private ClusterSet clusterSet; | ||||
|     private List<Point> initialPoints; | ||||
|     private transient ExecutorService exec; | ||||
|     private boolean useKmeansPlusPlus; | ||||
| 
 | ||||
| 
 | ||||
|     protected BaseClusteringAlgorithm(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) { | ||||
|         this.clusteringStrategy = clusteringStrategy; | ||||
|         this.exec = MultiThreadUtils.newExecutorService(); | ||||
|         this.useKmeansPlusPlus = useKmeansPlusPlus; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Set up a clustering algorithm for the given strategy | ||||
|      * @param clusteringStrategy the strategy to drive clustering | ||||
|      * @param useKmeansPlusPlus whether to use k-means++ style seeding of the initial centers | ||||
|      * @return the configured clustering algorithm | ||||
|      */ | ||||
|     public static BaseClusteringAlgorithm setup(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) { | ||||
|         return new BaseClusteringAlgorithm(clusteringStrategy, useKmeansPlusPlus); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Run the clustering algorithm on the given points | ||||
|      * @param points the points to cluster | ||||
|      * @return the resulting cluster set | ||||
|      */ | ||||
|     public ClusterSet applyTo(List<Point> points) { | ||||
|         resetState(points); | ||||
|         initClusters(useKmeansPlusPlus); | ||||
|         iterations(); | ||||
|         return clusterSet; | ||||
|     } | ||||
| 
 | ||||
|     private void resetState(List<Point> points) { | ||||
|         this.iterationHistory = new IterationHistory(); | ||||
|         this.currentIteration = 0; | ||||
|         this.clusterSet = null; | ||||
|         this.initialPoints = points; | ||||
|     } | ||||
| 
 | ||||
|     /** Run clustering iterations until a | ||||
|      * termination condition is hit. | ||||
|      * This is done by first classifying all points, | ||||
|      * and then updating cluster centers based on | ||||
|      * those classified points | ||||
|      */ | ||||
|     private void iterations() { | ||||
|         while ((clusteringStrategy.getTerminationCondition() != null | ||||
|                         && !clusteringStrategy.getTerminationCondition().isSatisfied(iterationHistory)) | ||||
|                         || iterationHistory.getMostRecentIterationInfo().isStrategyApplied()) { | ||||
|             currentIteration++; | ||||
|             removePoints(); | ||||
|             classifyPoints(); | ||||
|             applyClusteringStrategy(); | ||||
|             log.trace("Completed clustering iteration {}", currentIteration); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     protected void classifyPoints() { | ||||
|         //Classify points. This also adds each point to the ClusterSet | ||||
|         ClusterSetInfo clusterSetInfo = ClusterUtils.classifyPoints(clusterSet, initialPoints, exec); | ||||
|         //Update the cluster centers, based on the points within each cluster | ||||
|         ClusterUtils.refreshClustersCenters(clusterSet, clusterSetInfo, exec); | ||||
|         iterationHistory.getIterationsInfos().put(currentIteration, | ||||
|                         new IterationInfo(currentIteration, clusterSetInfo)); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Initialize the cluster centers by randomly selecting points, | ||||
|      * weighting the selection towards points far from existing centers | ||||
|      * @param kMeansPlusPlus if true, sample proportionally to the summed distances (k-means++ style); | ||||
|      *                       otherwise sample uniformly up to the maximum distance | ||||
|      */ | ||||
|     protected void initClusters(boolean kMeansPlusPlus) { | ||||
|         log.info("Generating initial clusters"); | ||||
|         List<Point> points = new ArrayList<>(initialPoints); | ||||
| 
 | ||||
|         //Initialize the ClusterSet with a single cluster center (based on position of one of the points chosen randomly) | ||||
|         val random = Nd4j.getRandom(); | ||||
|         Distance distanceFn = clusteringStrategy.getDistanceFunction(); | ||||
|         int initialClusterCount = clusteringStrategy.getInitialClusterCount(); | ||||
|         clusterSet = new ClusterSet(distanceFn, | ||||
|                         clusteringStrategy.inverseDistanceCalculation(), new long[]{initialClusterCount, points.get(0).getArray().length()}); | ||||
|         clusterSet.addNewClusterWithCenter(points.remove(random.nextInt(points.size()))); | ||||
| 
 | ||||
| 
 | ||||
|         //dxs: distances between | ||||
|         // each point and nearest cluster to that point | ||||
|         INDArray dxs = Nd4j.create(points.size()); | ||||
|         dxs.addi(clusteringStrategy.inverseDistanceCalculation() ? -Double.MAX_VALUE : Double.MAX_VALUE); | ||||
| 
 | ||||
|         //Generate the initial cluster centers, by randomly selecting a point between 0 and max distance | ||||
|         //Thus, we are more likely to select (as a new cluster center) a point that is far from an existing cluster | ||||
|         while (clusterSet.getClusterCount() < initialClusterCount && !points.isEmpty()) { | ||||
|             dxs = ClusterUtils.computeSquareDistancesFromNearestCluster(clusterSet, points, dxs, exec); | ||||
|             double summed = Nd4j.sum(dxs).getDouble(0); | ||||
|             double r = kMeansPlusPlus ? random.nextDouble() * summed : | ||||
|                                         random.nextFloat() * dxs.maxNumber().doubleValue(); | ||||
| 
 | ||||
|             for (int i = 0; i < dxs.length(); i++) { | ||||
|                 double distance = dxs.getDouble(i); | ||||
|                 Preconditions.checkState(distance >= 0, "Encountered negative distance: distance function is not valid? Distance " + | ||||
|                         "function must return values >= 0, got distance %s for function %s", distance, distanceFn); | ||||
|                 if (dxs.getDouble(i) >= r) { | ||||
|                     clusterSet.addNewClusterWithCenter(points.remove(i)); | ||||
|                     dxs = Nd4j.create(ArrayUtils.remove(dxs.data().asDouble(), i)); | ||||
|                     break; | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         ClusterSetInfo initialClusterSetInfo = ClusterUtils.computeClusterSetInfo(clusterSet); | ||||
|         iterationHistory.getIterationsInfos().put(currentIteration, | ||||
|                         new IterationInfo(currentIteration, initialClusterSetInfo)); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     protected void applyClusteringStrategy() { | ||||
|         if (!isStrategyApplicableNow()) | ||||
|             return; | ||||
| 
 | ||||
|         ClusterSetInfo clusterSetInfo = iterationHistory.getMostRecentClusterSetInfo(); | ||||
|         if (!clusteringStrategy.isAllowEmptyClusters()) { | ||||
|             int removedCount = removeEmptyClusters(clusterSetInfo); | ||||
|             if (removedCount > 0) { | ||||
|                 iterationHistory.getMostRecentIterationInfo().setStrategyApplied(true); | ||||
| 
 | ||||
|                 if (clusteringStrategy.isStrategyOfType(ClusteringStrategyType.FIXED_CLUSTER_COUNT) | ||||
|                                 && clusterSet.getClusterCount() < clusteringStrategy.getInitialClusterCount()) { | ||||
|                     int splitCount = ClusterUtils.splitMostSpreadOutClusters(clusterSet, clusterSetInfo, | ||||
|                                     clusteringStrategy.getInitialClusterCount() - clusterSet.getClusterCount(), exec); | ||||
|                     if (splitCount > 0) | ||||
|                         iterationHistory.getMostRecentIterationInfo().setStrategyApplied(true); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         if (clusteringStrategy.isStrategyOfType(ClusteringStrategyType.OPTIMIZATION)) | ||||
|             optimize(); | ||||
|     } | ||||
| 
 | ||||
|     protected void optimize() { | ||||
|         ClusterSetInfo clusterSetInfo = iterationHistory.getMostRecentClusterSetInfo(); | ||||
|         OptimisationStrategy optimization = (OptimisationStrategy) clusteringStrategy; | ||||
|         boolean applied = ClusterUtils.applyOptimization(optimization, clusterSet, clusterSetInfo, exec); | ||||
|         iterationHistory.getMostRecentIterationInfo().setStrategyApplied(applied); | ||||
|     } | ||||
| 
 | ||||
|     private boolean isStrategyApplicableNow() { | ||||
|         return clusteringStrategy.isOptimizationDefined() && iterationHistory.getIterationCount() != 0 | ||||
|                         && clusteringStrategy.isOptimizationApplicableNow(iterationHistory); | ||||
|     } | ||||
| 
 | ||||
|     protected int removeEmptyClusters(ClusterSetInfo clusterSetInfo) { | ||||
|         List<Cluster> removedClusters = clusterSet.removeEmptyClusters(); | ||||
|         clusterSetInfo.removeClusterInfos(removedClusters); | ||||
|         return removedClusters.size(); | ||||
|     } | ||||
| 
 | ||||
|     protected void removePoints() { | ||||
|         clusterSet.removePoints(); | ||||
|     } | ||||
| 
 | ||||
| } | ||||
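| 
 | ||||
| A rough end-to-end usage sketch (not from this commit): Point.toPoints, the FixedClusterCountStrategy factory, and its endWhenIterationCountEquals(...) termination setter are assumed from the surrounding clustering package and may differ in exact signature. | ||||
| 
 | ||||
|     // Rough, hypothetical sketch of driving the algorithm end to end | ||||
|     import org.deeplearning4j.clustering.algorithm.BaseClusteringAlgorithm; | ||||
|     import org.deeplearning4j.clustering.algorithm.Distance; | ||||
|     import org.deeplearning4j.clustering.cluster.ClusterSet; | ||||
|     import org.deeplearning4j.clustering.cluster.Point; | ||||
|     import org.deeplearning4j.clustering.strategy.FixedClusterCountStrategy; | ||||
|     import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
|     import java.util.List; | ||||
| 
 | ||||
|     public class ClusteringExample { | ||||
|         public static void main(String[] args) { | ||||
|             // 100 random 3-d points (Point.toPoints is assumed to wrap each row as a Point) | ||||
|             List<Point> points = Point.toPoints(Nd4j.rand(100, 3)); | ||||
|             // Assumed factory: 5 clusters, euclidean distance, non-inverse metric | ||||
|             FixedClusterCountStrategy strategy = FixedClusterCountStrategy.setup(5, Distance.EUCLIDEAN, false); | ||||
|             strategy.endWhenIterationCountEquals(10); // assumed setter; bounds the iteration loop | ||||
|             ClusterSet clusters = BaseClusteringAlgorithm.setup(strategy, true).applyTo(points); | ||||
|             System.out.println(clusters.getClusterCount()); // expected: 5 | ||||
|         } | ||||
|     } | ||||
| 
 | ||||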
| @ -1,38 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.algorithm; | ||||
| 
 | ||||
| import org.deeplearning4j.clustering.cluster.ClusterSet; | ||||
| import org.deeplearning4j.clustering.cluster.Point; | ||||
| 
 | ||||
| import java.util.List; | ||||
| 
 | ||||
| public interface ClusteringAlgorithm { | ||||
| 
 | ||||
|     /** | ||||
|      * Apply the clustering algorithm to the given points | ||||
|      * @param points the points to cluster | ||||
|      * @return the resulting set of clusters | ||||
|      */ | ||||
|     ClusterSet applyTo(List<Point> points); | ||||
| 
 | ||||
| } | ||||
| @ -1,41 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.algorithm; | ||||
| 
 | ||||
| public enum Distance { | ||||
|     EUCLIDEAN("euclidean"), | ||||
|     COSINE_DISTANCE("cosinedistance"), | ||||
|     COSINE_SIMILARITY("cosinesimilarity"), | ||||
|     MANHATTAN("manhattan"), | ||||
|     DOT("dot"), | ||||
|     JACCARD("jaccard"), | ||||
|     HAMMING("hamming"); | ||||
| 
 | ||||
|     private String functionName; | ||||
|     private Distance(String name) { | ||||
|         functionName = name; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String toString() { | ||||
|         return functionName; | ||||
|     } | ||||
| } | ||||
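| 
 | ||||
| Each constant carries the lower-case function name used to resolve the underlying distance op; a quick sketch of the mapping: | ||||
| 
 | ||||
|     import org.deeplearning4j.clustering.algorithm.Distance; | ||||
| 
 | ||||
|     public class DistanceNames { | ||||
|         public static void main(String[] args) { | ||||
|             // toString() yields the registered function name | ||||
|             System.out.println(Distance.EUCLIDEAN);         // prints "euclidean" | ||||
|             System.out.println(Distance.COSINE_SIMILARITY); // prints "cosinesimilarity" | ||||
|         } | ||||
|     } | ||||
| 
 | ||||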
| @ -1,105 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.cluster; | ||||
| 
 | ||||
| import org.deeplearning4j.clustering.algorithm.Distance; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.ops.ReduceOp; | ||||
| import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.common.primitives.Pair; | ||||
| 
 | ||||
| public class CentersHolder { | ||||
|     private INDArray centers; | ||||
|     private long index = 0; | ||||
| 
 | ||||
|     protected transient ReduceOp op; | ||||
|     protected ArgMin imin; | ||||
|     protected transient INDArray distances; | ||||
|     protected transient INDArray argMin; | ||||
| 
 | ||||
|     private long rows, cols; | ||||
| 
 | ||||
|     public CentersHolder(long rows, long cols) { | ||||
|         this.rows = rows; | ||||
|         this.cols = cols; | ||||
|     } | ||||
| 
 | ||||
|     public INDArray getCenters() { | ||||
|         return this.centers; | ||||
|     } | ||||
| 
 | ||||
|     public synchronized void addCenter(INDArray pointView) { | ||||
|         if (centers == null) | ||||
|             this.centers = Nd4j.create(pointView.dataType(), new long[] {rows, cols}); | ||||
| 
 | ||||
|         centers.putRow(index++, pointView); | ||||
|     } | ||||
| 
 | ||||
|     public synchronized Pair<Double, Long> getCenterByMinDistance(Point point, Distance distanceFunction) { | ||||
|         if (distances == null) | ||||
|             distances = Nd4j.create(centers.dataType(), centers.rows()); | ||||
| 
 | ||||
|         if (argMin == null) | ||||
|             argMin = Nd4j.createUninitialized(DataType.LONG, new long[0]); | ||||
| 
 | ||||
|         if (op == null) { | ||||
|             op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1); | ||||
|             imin = new ArgMin(distances, argMin); | ||||
|             op.setZ(distances); | ||||
|         } | ||||
| 
 | ||||
|         op.setY(point.getArray()); | ||||
| 
 | ||||
|         Nd4j.getExecutioner().exec(op); | ||||
|         Nd4j.getExecutioner().exec(imin); | ||||
| 
 | ||||
|         Pair<Double, Long> result = new Pair<>(); | ||||
|         result.setFirst(distances.getDouble(argMin.getLong(0))); | ||||
|         result.setSecond(argMin.getLong(0)); | ||||
|         return result; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Compute the distances from the given point to every stored center. | ||||
|      * Despite the name, this returns the full distance vector (the argmin op is also executed). | ||||
|      * @param point the query point | ||||
|      * @param distanceFunction the distance function to apply | ||||
|      * @return the vector of distances from the point to each center | ||||
|      */ | ||||
|     public synchronized INDArray getMinDistances(Point point, Distance distanceFunction) { | ||||
|         if (distances == null) | ||||
|             distances = Nd4j.create(centers.dataType(), centers.rows()); | ||||
| 
 | ||||
|         if (argMin == null) | ||||
|             argMin = Nd4j.createUninitialized(DataType.LONG, new long[0]); | ||||
| 
 | ||||
|         if (op == null) { | ||||
|             op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1); | ||||
|             imin = new ArgMin(distances, argMin); | ||||
|             op.setZ(distances); | ||||
|         } | ||||
| 
 | ||||
|         op.setY(point.getArray()); | ||||
| 
 | ||||
|         Nd4j.getExecutioner().exec(op); | ||||
|         Nd4j.getExecutioner().exec(imin); | ||||
| 
 | ||||
|         return distances; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,150 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.cluster; | ||||
| 
 | ||||
| import lombok.Data; | ||||
| import org.deeplearning4j.clustering.algorithm.Distance; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| import java.util.ArrayList; | ||||
| import java.util.Collections; | ||||
| import java.util.List; | ||||
| import java.util.UUID; | ||||
| 
 | ||||
| @Data | ||||
| public class Cluster implements Serializable { | ||||
| 
 | ||||
|     private String id = UUID.randomUUID().toString(); | ||||
|     private String label; | ||||
| 
 | ||||
|     private Point center; | ||||
|     private List<Point> points = Collections.synchronizedList(new ArrayList<Point>()); | ||||
|     private boolean inverse = false; | ||||
|     private Distance distanceFunction; | ||||
| 
 | ||||
|     public Cluster() { | ||||
|         super(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Create a cluster with the given center, using a non-inverse distance function | ||||
|      * @param center the center point of the cluster | ||||
|      * @param distanceFunction the distance function to use | ||||
|      */ | ||||
|     public Cluster(Point center, Distance distanceFunction) { | ||||
|         this(center, false, distanceFunction); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Create a cluster with the given center | ||||
|      * @param center the center point of the cluster | ||||
|      * @param inverse whether the distance function is inverse (higher values mean closer) | ||||
|      * @param distanceFunction the distance function to use | ||||
|      */ | ||||
|     public Cluster(Point center, boolean inverse, Distance distanceFunction) { | ||||
|         this.distanceFunction = distanceFunction; | ||||
|         this.inverse = inverse; | ||||
|         setCenter(center); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the distance from the cluster center to the given point | ||||
|      * @param point the point to get the distance for | ||||
|      * @return the distance between the cluster center and the point | ||||
|      */ | ||||
|     public double getDistanceToCenter(Point point) { | ||||
|         return Nd4j.getExecutioner().execAndReturn( | ||||
|                         ClusterUtils.createDistanceFunctionOp(distanceFunction, center.getArray(), point.getArray())) | ||||
|                         .getFinalResult().doubleValue(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Add a point to the cluster, moving the cluster center | ||||
|      * @param point the point to add | ||||
|      */ | ||||
|     public void addPoint(Point point) { | ||||
|         addPoint(point, true); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Add a point to the cluster | ||||
|      * @param point the point to add | ||||
|      * @param moveClusterCenter whether to update | ||||
|      *                          the cluster centroid or not | ||||
|      */ | ||||
|     public void addPoint(Point point, boolean moveClusterCenter) { | ||||
|         if (moveClusterCenter) { | ||||
|             if (isInverse()) { | ||||
|                 center.getArray().muli(points.size()).subi(point.getArray()).divi(points.size() + 1); | ||||
|             } else { | ||||
|                 center.getArray().muli(points.size()).addi(point.getArray()).divi(points.size() + 1); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         getPoints().add(point); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Clear out the points | ||||
|      */ | ||||
|     public void removePoints() { | ||||
|         if (getPoints() != null) | ||||
|             getPoints().clear(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Whether the cluster is empty or not | ||||
|      * @return true if the cluster contains no points | ||||
|      */ | ||||
|     public boolean isEmpty() { | ||||
|         return points == null || points.isEmpty(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Return the point with the given id | ||||
|      * @param id the id of the point to look up | ||||
|      * @return the matching point, or null if none is found | ||||
|      */ | ||||
|     public Point getPoint(String id) { | ||||
|         for (Point point : points) | ||||
|             if (id.equals(point.getId())) | ||||
|                 return point; | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Remove the point with the given id and return it | ||||
|      * @param id the id of the point to remove | ||||
|      * @return the removed point, or null if none is found | ||||
|      */ | ||||
|     public Point removePoint(String id) { | ||||
|         Point removePoint = null; | ||||
|         for (Point point : points) | ||||
|             if (id.equals(point.getId())) | ||||
|                 removePoint = point; | ||||
|         if (removePoint != null) | ||||
|             points.remove(removePoint); | ||||
|         return removePoint; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
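| 
 | ||||
| The moveClusterCenter update in addPoint is an incremental mean: with n points currently in the cluster, the new center is (center * n + point) / (n + 1), with a subtraction instead of an addition for inverse metrics. A tiny numeric check with made-up values: | ||||
| 
 | ||||
|     // Numeric sketch of the running-mean center update (values are illustrative) | ||||
|     import org.nd4j.linalg.api.ndarray.INDArray; | ||||
|     import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
|     public class RunningMeanCheck { | ||||
|         public static void main(String[] args) { | ||||
|             INDArray center = Nd4j.create(new double[]{2.0, 2.0}); // current center of a 2-point cluster | ||||
|             INDArray point  = Nd4j.create(new double[]{5.0, 8.0}); // the point being added | ||||
|             int n = 2; | ||||
|             // (center * n + point) / (n + 1) -- the same arithmetic as Cluster.addPoint | ||||
|             INDArray updated = center.mul(n).add(point).div(n + 1); | ||||
|             System.out.println(updated); // prints approximately [3.0, 4.0] | ||||
|         } | ||||
|     } | ||||
| 
 | ||||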
| @ -1,259 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.cluster; | ||||
| 
 | ||||
| import lombok.Data; | ||||
| import org.deeplearning4j.clustering.algorithm.Distance; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.common.primitives.Pair; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| import java.util.*; | ||||
| 
 | ||||
| @Data | ||||
| public class ClusterSet implements Serializable { | ||||
| 
 | ||||
|     private Distance distanceFunction; | ||||
|     private List<Cluster> clusters; | ||||
|     private CentersHolder centersHolder; | ||||
|     private Map<String, String> pointDistribution; | ||||
|     private boolean inverse; | ||||
| 
 | ||||
|     public ClusterSet(boolean inverse) { | ||||
|         this(null, inverse, null); | ||||
|     } | ||||
| 
 | ||||
|     public ClusterSet(Distance distanceFunction, boolean inverse, long[] shape) { | ||||
|         this.distanceFunction = distanceFunction; | ||||
|         this.inverse = inverse; | ||||
|         this.clusters = Collections.synchronizedList(new ArrayList<Cluster>()); | ||||
|         this.pointDistribution = Collections.synchronizedMap(new HashMap<String, String>()); | ||||
|         if (shape != null) | ||||
|             this.centersHolder = new CentersHolder(shape[0], shape[1]); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public boolean isInverse() { | ||||
|         return inverse; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Create a new cluster centered on the given point and register it | ||||
|      * @param center the point to use as the new cluster's center | ||||
|      * @return the newly created cluster | ||||
|      */ | ||||
|     public Cluster addNewClusterWithCenter(Point center) { | ||||
|         Cluster newCluster = new Cluster(center, distanceFunction); | ||||
|         getClusters().add(newCluster); | ||||
|         setPointLocation(center, newCluster); | ||||
|         centersHolder.addCenter(center.getArray()); | ||||
|         return newCluster; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Classify a point into its nearest cluster, moving the cluster center | ||||
|      * @param point the point to classify | ||||
|      * @return the resulting classification | ||||
|      */ | ||||
|     public PointClassification classifyPoint(Point point) { | ||||
|         return classifyPoint(point, true); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Classify each point into its nearest cluster, moving cluster centers | ||||
|      * @param points the points to classify | ||||
|      */ | ||||
|     public void classifyPoints(List<Point> points) { | ||||
|         classifyPoints(points, true); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Classify each point into its nearest cluster | ||||
|      * @param points the points to classify | ||||
|      * @param moveClusterCenter whether to update cluster centers as points are added | ||||
|      */ | ||||
|     public void classifyPoints(List<Point> points, boolean moveClusterCenter) { | ||||
|         for (Point point : points) | ||||
|             classifyPoint(point, moveClusterCenter); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Classify a point into its nearest cluster | ||||
|      * @param point the point to classify | ||||
|      * @param moveClusterCenter whether to update the cluster center | ||||
|      * @return the nearest cluster, the distance to it, and whether the point changed cluster | ||||
|      */ | ||||
|     public PointClassification classifyPoint(Point point, boolean moveClusterCenter) { | ||||
|         Pair<Cluster, Double> nearestCluster = nearestCluster(point); | ||||
|         Cluster newCluster = nearestCluster.getKey(); | ||||
|         boolean locationChange = isPointLocationChange(point, newCluster); | ||||
|         addPointToCluster(point, newCluster, moveClusterCenter); | ||||
|         return new PointClassification(nearestCluster.getKey(), nearestCluster.getValue(), locationChange); | ||||
|     } | ||||
| 
 | ||||
|     private boolean isPointLocationChange(Point point, Cluster newCluster) { | ||||
|         if (!getPointDistribution().containsKey(point.getId())) | ||||
|             return true; | ||||
|         return !getPointDistribution().get(point.getId()).equals(newCluster.getId()); | ||||
|     } | ||||
| 
 | ||||
|     private void addPointToCluster(Point point, Cluster cluster, boolean moveClusterCenter) { | ||||
|         cluster.addPoint(point, moveClusterCenter); | ||||
|         setPointLocation(point, cluster); | ||||
|     } | ||||
| 
 | ||||
|     private void setPointLocation(Point point, Cluster cluster) { | ||||
|         pointDistribution.put(point.getId(), cluster.getId()); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Find the nearest cluster to the given point | ||||
|      * @param point the point to search from | ||||
|      * @return the nearest cluster and the distance to its center | ||||
|      */ | ||||
|     public Pair<Cluster, Double> nearestCluster(Point point) { | ||||
| 
 | ||||
|         Pair<Double, Long> nearestCenterData = centersHolder.getCenterByMinDistance(point, distanceFunction); | ||||
|         Cluster nearestCluster = getClusters().get(nearestCenterData.getSecond().intValue()); | ||||
|         double minDistance = nearestCenterData.getFirst(); | ||||
|         return Pair.of(nearestCluster, minDistance); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Compute the distance between two points under the configured distance function | ||||
|      * @param m1 the first point | ||||
|      * @param m2 the second point | ||||
|      * @return the distance between the two points | ||||
|      */ | ||||
|     public double getDistance(Point m1, Point m2) { | ||||
|         return Nd4j.getExecutioner() | ||||
|                         .execAndReturn(ClusterUtils.createDistanceFunctionOp(distanceFunction, m1.getArray(), m2.getArray())) | ||||
|                         .getFinalResult().doubleValue(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the id of the center point of the given cluster. | ||||
|      * | ||||
|      * @param clusterId the id of the cluster | ||||
|      * @return the id of the cluster's center, or null if no such cluster exists | ||||
|      */ | ||||
|     public String getClusterCenterId(String clusterId) { | ||||
|         Point clusterCenter = getClusterCenter(clusterId); | ||||
|         return clusterCenter == null ? null : clusterCenter.getId(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the center point of the given cluster. | ||||
|      * | ||||
|      * @param clusterId the id of the cluster | ||||
|      * @return the cluster's center, or null if no such cluster exists | ||||
|      */ | ||||
|     public Point getClusterCenter(String clusterId) { | ||||
|         Cluster cluster = getCluster(clusterId); | ||||
|         return cluster == null ? null : cluster.getCenter(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Look up a cluster by id. | ||||
|      * | ||||
|      * @param id the id of the cluster | ||||
|      * @return the matching cluster, or null if none matches | ||||
|      */ | ||||
|     public Cluster getCluster(String id) { | ||||
|         for (int i = 0, j = clusters.size(); i < j; i++) | ||||
|             if (id.equals(clusters.get(i).getId())) | ||||
|                 return clusters.get(i); | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @return the number of clusters in this set, or 0 if none have been created | ||||
|      */ | ||||
|     public int getClusterCount() { | ||||
|         return getClusters() == null ? 0 : getClusters().size(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Remove all points from every cluster in this set. | ||||
|      */ | ||||
|     public void removePoints() { | ||||
|         for (Cluster cluster : getClusters()) | ||||
|             cluster.removePoints(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the clusters containing the most points, in descending order of size. | ||||
|      * | ||||
|      * @param count the number of clusters to return | ||||
|      * @return the {@code count} most populated clusters | ||||
|      */ | ||||
|     public List<Cluster> getMostPopulatedClusters(int count) { | ||||
|         List<Cluster> mostPopulated = new ArrayList<>(clusters); | ||||
|         Collections.sort(mostPopulated, new Comparator<Cluster>() { | ||||
|             public int compare(Cluster o1, Cluster o2) { | ||||
|                 return Integer.compare(o2.getPoints().size(), o1.getPoints().size()); | ||||
|             } | ||||
|         }); | ||||
|         return mostPopulated.subList(0, Math.min(count, mostPopulated.size())); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Remove all empty clusters from this set. | ||||
|      * | ||||
|      * @return the clusters that were removed | ||||
|      */ | ||||
|     public List<Cluster> removeEmptyClusters() { | ||||
|         List<Cluster> emptyClusters = new ArrayList<>(); | ||||
|         for (Cluster cluster : clusters) | ||||
|             if (cluster.isEmpty()) | ||||
|                 emptyClusters.add(cluster); | ||||
|         clusters.removeAll(emptyClusters); | ||||
|         return emptyClusters; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
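| // A minimal usage sketch (not part of the removed sources): classifying a new point | ||||
| // against an already-built ClusterSet. The clusterSet variable is illustrative. | ||||
| // | ||||
| //     Point query = new Point(Nd4j.create(new double[] {1.0, 2.0, 3.0})); | ||||
| //     PointClassification result = clusterSet.classifyPoint(query, false); | ||||
| //     System.out.println(result.getCluster().getId() + " @ " + result.getDistanceFromCenter()); | ||||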
| @ -1,531 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.cluster; | ||||
| 
 | ||||
| import lombok.AccessLevel; | ||||
| import lombok.NoArgsConstructor; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import lombok.val; | ||||
| import org.apache.commons.lang3.ArrayUtils; | ||||
| import org.deeplearning4j.clustering.algorithm.Distance; | ||||
| import org.deeplearning4j.clustering.info.ClusterInfo; | ||||
| import org.deeplearning4j.clustering.info.ClusterSetInfo; | ||||
| import org.deeplearning4j.clustering.optimisation.ClusteringOptimizationType; | ||||
| import org.deeplearning4j.clustering.strategy.OptimisationStrategy; | ||||
| import org.deeplearning4j.clustering.util.MathUtils; | ||||
| import org.deeplearning4j.clustering.util.MultiThreadUtils; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.ops.ReduceOp; | ||||
| import org.nd4j.linalg.api.ops.impl.reduce3.*; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.util.*; | ||||
| import java.util.concurrent.ExecutorService; | ||||
| 
 | ||||
| @NoArgsConstructor(access = AccessLevel.PRIVATE) | ||||
| @Slf4j | ||||
| public class ClusterUtils { | ||||
| 
 | ||||
|     /** Classify the set of points based on the cluster centers. This also adds each point to the ClusterSet. */ | ||||
|     public static ClusterSetInfo classifyPoints(final ClusterSet clusterSet, List<Point> points, | ||||
|                     ExecutorService executorService) { | ||||
|         final ClusterSetInfo clusterSetInfo = ClusterSetInfo.initialize(clusterSet, true); | ||||
| 
 | ||||
|         // Executed sequentially; a previous implementation submitted one Runnable per point | ||||
|         // to MultiThreadUtils.parallelTasks(tasks, executorService). | ||||
|         for (final Point point : points) { | ||||
|             try { | ||||
|                 PointClassification result = classifyPoint(clusterSet, point); | ||||
|                 if (result.isNewLocation()) | ||||
|                     clusterSetInfo.getPointLocationChange().incrementAndGet(); | ||||
|                 clusterSetInfo.getClusterInfo(result.getCluster().getId()).getPointDistancesFromCenter() | ||||
|                                 .put(point.getId(), result.getDistanceFromCenter()); | ||||
|             } catch (Throwable t) { | ||||
|                 log.warn("Error classifying point", t); | ||||
|             } | ||||
|         } | ||||
|         return clusterSetInfo; | ||||
|     } | ||||
| 
 | ||||
|     public static PointClassification classifyPoint(ClusterSet clusterSet, Point point) { | ||||
|         return clusterSet.classifyPoint(point, false); | ||||
|     } | ||||
| 
 | ||||
|     public static void refreshClustersCenters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo, | ||||
|                     ExecutorService executorService) { | ||||
|         // Executed sequentially; a previous implementation submitted one Runnable per cluster | ||||
|         // to MultiThreadUtils.parallelTasks(tasks, executorService). | ||||
|         int nClusters = clusterSet.getClusterCount(); | ||||
|         for (int i = 0; i < nClusters; i++) { | ||||
|             final Cluster cluster = clusterSet.getClusters().get(i); | ||||
|             try { | ||||
|                 final ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId()); | ||||
|                 refreshClusterCenter(cluster, clusterInfo); | ||||
|                 deriveClusterInfoDistanceStatistics(clusterInfo); | ||||
|             } catch (Throwable t) { | ||||
|                 log.warn("Error refreshing cluster centers", t); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     public static void refreshClusterCenter(Cluster cluster, ClusterInfo clusterInfo) { | ||||
|         int pointsCount = cluster.getPoints().size(); | ||||
|         if (pointsCount == 0) | ||||
|             return; | ||||
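|         // Average the member points: accumulate a signed sum (subtracting for "inverse" | ||||
|         // clusters), then divide by the point count. | ||||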
|         Point center = new Point(Nd4j.create(cluster.getPoints().get(0).getArray().length())); | ||||
|         for (Point point : cluster.getPoints()) { | ||||
|             INDArray arr = point.getArray(); | ||||
|             if (cluster.isInverse()) | ||||
|                 center.getArray().subi(arr); | ||||
|             else | ||||
|                 center.getArray().addi(arr); | ||||
|         } | ||||
|         center.getArray().divi(pointsCount); | ||||
|         cluster.setCenter(center); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Derive max/total/average/variance statistics from the per-point distances stored in the given cluster info. | ||||
|      * | ||||
|      * @param info the cluster info to update in place | ||||
|      */ | ||||
|     public static void deriveClusterInfoDistanceStatistics(ClusterInfo info) { | ||||
|         int pointCount = info.getPointDistancesFromCenter().size(); | ||||
|         if (pointCount == 0) | ||||
|             return; | ||||
| 
 | ||||
|         double[] distances = | ||||
|                         ArrayUtils.toPrimitive(info.getPointDistancesFromCenter().values().toArray(new Double[] {})); | ||||
|         double max = info.isInverse() ? MathUtils.min(distances) : MathUtils.max(distances); | ||||
|         double total = MathUtils.sum(distances); | ||||
|         info.setMaxPointDistanceFromCenter(max); | ||||
|         info.setTotalPointDistanceFromCenter(total); | ||||
|         info.setAveragePointDistanceFromCenter(total / pointCount); | ||||
|         info.setPointDistanceFromCenterVariance(MathUtils.variance(distances)); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Compute, for each point, its distance to the most recently added cluster center | ||||
|      * (squared, for non-inverse measures), keeping the previous value where it is better. | ||||
|      * | ||||
|      * @param clusterSet      the current cluster set | ||||
|      * @param points          the candidate points | ||||
|      * @param previousDxs     the previous per-point distances | ||||
|      * @param executorService unused; retained for API compatibility | ||||
|      * @return the updated per-point distances | ||||
|      */ | ||||
|     public static INDArray computeSquareDistancesFromNearestCluster(final ClusterSet clusterSet, | ||||
|                     final List<Point> points, INDArray previousDxs, ExecutorService executorService) { | ||||
|         final int pointsCount = points.size(); | ||||
|         final INDArray dxs = Nd4j.create(pointsCount); | ||||
|         final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1); | ||||
| 
 | ||||
|         // Executed sequentially; a previous implementation submitted one Runnable per point | ||||
|         // to MultiThreadUtils.parallelTasks(tasks, executorService). | ||||
|         for (int i = 0; i < pointsCount; i++) { | ||||
|             try { | ||||
|                 Point point = points.get(i); | ||||
|                 double dist = clusterSet.isInverse() ? newCluster.getDistanceToCenter(point) | ||||
|                         : Math.pow(newCluster.getDistanceToCenter(point), 2); | ||||
|                 dxs.putScalar(i, dist); | ||||
|             } catch (Throwable t) { | ||||
|                 log.warn("Error computing squared distance from nearest cluster", t); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         // Keep the better of the previous and the new distance for each point | ||||
|         // (larger for inverse/similarity measures, smaller otherwise). | ||||
|         for (int i = 0; i < pointsCount; i++) { | ||||
|             double previousMinDistance = previousDxs.getDouble(i); | ||||
|             if (clusterSet.isInverse()) { | ||||
|                 if (dxs.getDouble(i) < previousMinDistance) { | ||||
| 
 | ||||
|                     dxs.putScalar(i, previousMinDistance); | ||||
|                 } | ||||
|             } else if (dxs.getDouble(i) > previousMinDistance) | ||||
|                 dxs.putScalar(i, previousMinDistance); | ||||
|         } | ||||
| 
 | ||||
|         return dxs; | ||||
|     } | ||||
| 
 | ||||
|     public static INDArray computeWeightedProbaDistancesFromNearestCluster(final ClusterSet clusterSet, | ||||
|                                                                     final List<Point> points, INDArray previousDxs) { | ||||
|         final int pointsCount = points.size(); | ||||
|         final INDArray dxs = Nd4j.create(pointsCount); | ||||
|         final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1); | ||||
| 
 | ||||
|         // Store the running (cumulative) sum of squared distances, as used for weighted sampling. | ||||
|         double sum = 0; | ||||
|         for (int i = 0; i < pointsCount; i++) { | ||||
|             Point point = points.get(i); | ||||
|             double dist = Math.pow(newCluster.getDistanceToCenter(point), 2); | ||||
|             sum += dist; | ||||
|             dxs.putScalar(i, sum); | ||||
|         } | ||||
| 
 | ||||
|         return dxs; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Compute cluster statistics for the whole cluster set, using a temporary executor. | ||||
|      * | ||||
|      * @param clusterSet the cluster set to analyse | ||||
|      * @return the computed cluster set info | ||||
|      */ | ||||
|     public static ClusterSetInfo computeClusterSetInfo(ClusterSet clusterSet) { | ||||
|         ExecutorService executor = MultiThreadUtils.newExecutorService(); | ||||
|         ClusterSetInfo info = computeClusterSetInfo(clusterSet, executor); | ||||
|         executor.shutdownNow(); | ||||
|         return info; | ||||
|     } | ||||
| 
 | ||||
|     public static ClusterSetInfo computeClusterSetInfo(final ClusterSet clusterSet, ExecutorService executorService) { | ||||
|         final ClusterSetInfo info = new ClusterSetInfo(clusterSet.isInverse(), true); | ||||
|         int clusterCount = clusterSet.getClusterCount(); | ||||
| 
 | ||||
|         // Per-cluster statistics. Executed sequentially; a previous implementation submitted | ||||
|         // one Runnable per cluster to MultiThreadUtils.parallelTasks(tasks, executorService). | ||||
|         for (int i = 0; i < clusterCount; i++) { | ||||
|             final Cluster cluster = clusterSet.getClusters().get(i); | ||||
|             try { | ||||
|                 info.getClustersInfos().put(cluster.getId(), | ||||
|                         computeClusterInfos(cluster, clusterSet.getDistanceFunction())); | ||||
|             } catch (Throwable t) { | ||||
|                 log.warn("Error computing cluster set info", t); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         // Pairwise distances between cluster centers (upper triangle only). | ||||
|         for (int i = 0; i < clusterCount; i++) { | ||||
|             final Cluster fromCluster = clusterSet.getClusters().get(i); | ||||
|             try { | ||||
|                 for (int k = i + 1; k < clusterCount; k++) { | ||||
|                     Cluster toCluster = clusterSet.getClusters().get(k); | ||||
|                     double distance = Nd4j.getExecutioner() | ||||
|                                     .execAndReturn(ClusterUtils.createDistanceFunctionOp( | ||||
|                                                     clusterSet.getDistanceFunction(), | ||||
|                                                     fromCluster.getCenter().getArray(), | ||||
|                                                     toCluster.getCenter().getArray())) | ||||
|                                     .getFinalResult().doubleValue(); | ||||
|                     info.getDistancesBetweenClustersCenters().put(fromCluster.getId(), toCluster.getId(), | ||||
|                                     distance); | ||||
|                 } | ||||
|             } catch (Throwable t) { | ||||
|                 log.warn("Error computing distances", t); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         return info; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Compute per-point distance statistics for a single cluster. | ||||
|      * | ||||
|      * @param cluster          the cluster to analyse | ||||
|      * @param distanceFunction the distance function applied between each point and the center | ||||
|      * @return the computed cluster info | ||||
|      */ | ||||
|     public static ClusterInfo computeClusterInfos(Cluster cluster, Distance distanceFunction) { | ||||
|         ClusterInfo info = new ClusterInfo(cluster.isInverse(), true); | ||||
|         for (int i = 0, j = cluster.getPoints().size(); i < j; i++) { | ||||
|             Point point = cluster.getPoints().get(i); | ||||
|             //shouldn't need to inverse here. other parts of | ||||
|             //the code should interpret the "distance" or score here | ||||
|             double distance = Nd4j.getExecutioner() | ||||
|                             .execAndReturn(ClusterUtils.createDistanceFunctionOp(distanceFunction, | ||||
|                                             cluster.getCenter().getArray(), point.getArray())) | ||||
|                             .getFinalResult().doubleValue(); | ||||
|             info.getPointDistancesFromCenter().put(point.getId(), distance); | ||||
|             double total = info.getTotalPointDistanceFromCenter() + distance; | ||||
|             info.setTotalPointDistanceFromCenter(total); | ||||
|         } | ||||
| 
 | ||||
|         if (!cluster.getPoints().isEmpty()) | ||||
|             info.setAveragePointDistanceFromCenter(info.getTotalPointDistanceFromCenter() / cluster.getPoints().size()); | ||||
|         return info; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Apply the given optimisation strategy to the cluster set, splitting clusters where required. | ||||
|      * | ||||
|      * @param optimization   the optimisation strategy to apply | ||||
|      * @param clusterSet     the cluster set to optimise | ||||
|      * @param clusterSetInfo the current statistics for the cluster set | ||||
|      * @param executor       the executor used for splitting | ||||
|      * @return true if at least one cluster was split | ||||
|      */ | ||||
|     public static boolean applyOptimization(OptimisationStrategy optimization, ClusterSet clusterSet, | ||||
|                     ClusterSetInfo clusterSetInfo, ExecutorService executor) { | ||||
| 
 | ||||
|         if (optimization.isClusteringOptimizationType( | ||||
|                         ClusteringOptimizationType.MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE)) { | ||||
|             int splitCount = ClusterUtils.splitClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet, | ||||
|                             clusterSetInfo, optimization.getClusteringOptimizationValue(), executor); | ||||
|             return splitCount > 0; | ||||
|         } | ||||
| 
 | ||||
|         if (optimization.isClusteringOptimizationType( | ||||
|                         ClusteringOptimizationType.MINIMIZE_MAXIMUM_POINT_TO_CENTER_DISTANCE)) { | ||||
|             int splitCount = ClusterUtils.splitClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet, | ||||
|                             clusterSetInfo, optimization.getClusteringOptimizationValue(), executor); | ||||
|             return splitCount > 0; | ||||
|         } | ||||
| 
 | ||||
|         return false; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the clusters with the largest total point-to-center distance. | ||||
|      * | ||||
|      * @param clusterSet the cluster set to search | ||||
|      * @param info       the current statistics for the cluster set | ||||
|      * @param count      the number of clusters to return | ||||
|      * @return the {@code count} most spread-out clusters | ||||
|      */ | ||||
|     public static List<Cluster> getMostSpreadOutClusters(final ClusterSet clusterSet, final ClusterSetInfo info, | ||||
|                     int count) { | ||||
|         List<Cluster> clusters = new ArrayList<>(clusterSet.getClusters()); | ||||
|         Collections.sort(clusters, new Comparator<Cluster>() { | ||||
|             public int compare(Cluster o1, Cluster o2) { | ||||
|                 Double o1TotalDistance = info.getClusterInfo(o1.getId()).getTotalPointDistanceFromCenter(); | ||||
|                 Double o2TotalDistance = info.getClusterInfo(o2.getId()).getTotalPointDistanceFromCenter(); | ||||
|                 int comp = o1TotalDistance.compareTo(o2TotalDistance); | ||||
|                 return !clusterSet.getClusters().get(0).isInverse() ? -comp : comp; | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         return clusters.subList(0, Math.min(count, clusters.size())); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the clusters whose average point-to-center distance exceeds the given threshold | ||||
|      * (or falls below it, for inverse/similarity measures). | ||||
|      * | ||||
|      * @param clusterSet             the cluster set to search | ||||
|      * @param info                   the current statistics for the cluster set | ||||
|      * @param maximumAverageDistance the threshold on the average distance | ||||
|      * @return the matching clusters | ||||
|      */ | ||||
|     public static List<Cluster> getClustersWhereAverageDistanceFromCenterGreaterThan(final ClusterSet clusterSet, | ||||
|                     final ClusterSetInfo info, double maximumAverageDistance) { | ||||
|         List<Cluster> clusters = new ArrayList<>(); | ||||
|         for (Cluster cluster : clusterSet.getClusters()) { | ||||
|             ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId()); | ||||
|             if (clusterInfo != null) { | ||||
|                 //distances | ||||
|                 if (clusterInfo.isInverse()) { | ||||
|                     if (clusterInfo.getAveragePointDistanceFromCenter() < maximumAverageDistance) | ||||
|                         clusters.add(cluster); | ||||
|                 } else { | ||||
|                     if (clusterInfo.getAveragePointDistanceFromCenter() > maximumAverageDistance) | ||||
|                         clusters.add(cluster); | ||||
|                 } | ||||
| 
 | ||||
|             } | ||||
| 
 | ||||
|         } | ||||
|         return clusters; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the clusters whose maximum point-to-center distance exceeds the given threshold | ||||
|      * (or falls below it, for inverse/similarity measures). | ||||
|      * | ||||
|      * @param clusterSet      the cluster set to search | ||||
|      * @param info            the current statistics for the cluster set | ||||
|      * @param maximumDistance the threshold on the maximum distance | ||||
|      * @return the matching clusters | ||||
|      */ | ||||
|     public static List<Cluster> getClustersWhereMaximumDistanceFromCenterGreaterThan(final ClusterSet clusterSet, | ||||
|                     final ClusterSetInfo info, double maximumDistance) { | ||||
|         List<Cluster> clusters = new ArrayList<>(); | ||||
|         for (Cluster cluster : clusterSet.getClusters()) { | ||||
|             ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId()); | ||||
|             if (clusterInfo != null) { | ||||
|                 if (clusterInfo.isInverse()) { | ||||
|                     if (clusterInfo.getMaxPointDistanceFromCenter() < maximumDistance) | ||||
|                         clusters.add(cluster); | ||||
|                 } else if (clusterInfo.getMaxPointDistanceFromCenter() > maximumDistance) { | ||||
|                     clusters.add(cluster); | ||||
|                 } | ||||
|             } | ||||
|         } | ||||
|         return clusters; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Split the {@code count} most spread-out clusters. | ||||
|      * | ||||
|      * @param clusterSet      the cluster set to modify | ||||
|      * @param clusterSetInfo  the current statistics for the cluster set | ||||
|      * @param count           the number of clusters to split | ||||
|      * @param executorService the executor used for splitting | ||||
|      * @return the number of clusters that were split | ||||
|      */ | ||||
|     public static int splitMostSpreadOutClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count, | ||||
|                     ExecutorService executorService) { | ||||
|         List<Cluster> clustersToSplit = getMostSpreadOutClusters(clusterSet, clusterSetInfo, count); | ||||
|         splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService); | ||||
|         return clustersToSplit.size(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Split every cluster whose average point-to-center distance exceeds the given threshold. | ||||
|      * | ||||
|      * @param clusterSet               the cluster set to modify | ||||
|      * @param clusterSetInfo           the current statistics for the cluster set | ||||
|      * @param maxWithinClusterDistance the threshold on the average distance | ||||
|      * @param executorService          the executor used for splitting | ||||
|      * @return the number of clusters that were split | ||||
|      */ | ||||
|     public static int splitClustersWhereAverageDistanceFromCenterGreaterThan(ClusterSet clusterSet, | ||||
|                     ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) { | ||||
|         List<Cluster> clustersToSplit = getClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo, | ||||
|                         maxWithinClusterDistance); | ||||
|         splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService); | ||||
|         return clustersToSplit.size(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Split every cluster whose maximum point-to-center distance exceeds the given threshold. | ||||
|      * | ||||
|      * @param clusterSet               the cluster set to modify | ||||
|      * @param clusterSetInfo           the current statistics for the cluster set | ||||
|      * @param maxWithinClusterDistance the threshold on the maximum distance | ||||
|      * @param executorService          the executor used for splitting | ||||
|      * @return the number of clusters that were split | ||||
|      */ | ||||
|     public static int splitClustersWhereMaximumDistanceFromCenterGreaterThan(ClusterSet clusterSet, | ||||
|                     ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) { | ||||
|         List<Cluster> clustersToSplit = getClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo, | ||||
|                         maxWithinClusterDistance); | ||||
|         splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService); | ||||
|         return clustersToSplit.size(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Split the {@code count} most populated clusters. | ||||
|      * | ||||
|      * @param clusterSet      the cluster set to modify | ||||
|      * @param clusterSetInfo  the current statistics for the cluster set | ||||
|      * @param count           the number of clusters to split | ||||
|      * @param executorService the executor used for splitting | ||||
|      */ | ||||
|     public static void splitMostPopulatedClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count, | ||||
|                     ExecutorService executorService) { | ||||
|         List<Cluster> clustersToSplit = clusterSet.getMostPopulatedClusters(count); | ||||
|         splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Split each of the given clusters by promoting one of its points farther from the center | ||||
|      * than {@code maxDistance} to be the center of a new cluster. | ||||
|      * | ||||
|      * @param clusterSet      the cluster set to modify | ||||
|      * @param clusterSetInfo  the current statistics for the cluster set | ||||
|      * @param clusters        the clusters to split | ||||
|      * @param maxDistance     the distance threshold used to pick the promoted point | ||||
|      * @param executorService the executor running the split tasks | ||||
|      */ | ||||
|     public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo, | ||||
|                     List<Cluster> clusters, final double maxDistance, ExecutorService executorService) { | ||||
|         final Random random = new Random(); | ||||
|         List<Runnable> tasks = new ArrayList<>(); | ||||
|         for (final Cluster cluster : clusters) { | ||||
|             tasks.add(new Runnable() { | ||||
|                 public void run() { | ||||
|                     try { | ||||
|                         ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId()); | ||||
|                         List<String> fartherPoints = clusterInfo.getPointsFartherFromCenterThan(maxDistance); | ||||
|                         if (fartherPoints.isEmpty()) | ||||
|                             return; // nothing beyond the threshold; leave this cluster unsplit | ||||
|                         // pick randomly among the (up to) three farthest points | ||||
|                         int rank = Math.min(fartherPoints.size(), 3); | ||||
|                         String pointId = fartherPoints.get(random.nextInt(rank)); | ||||
|                         Point point = cluster.removePoint(pointId); | ||||
|                         clusterSet.addNewClusterWithCenter(point); | ||||
|                     } catch (Throwable t) { | ||||
|                         log.warn("Error splitting clusters", t); | ||||
|                     } | ||||
|                 } | ||||
|             }); | ||||
|         } | ||||
|         MultiThreadUtils.parallelTasks(tasks, executorService); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Split each of the given clusters by promoting one of its points, chosen at random, | ||||
|      * to be the center of a new cluster. | ||||
|      * | ||||
|      * @param clusterSet      the cluster set to modify | ||||
|      * @param clusterSetInfo  the current statistics for the cluster set | ||||
|      * @param clusters        the clusters to split | ||||
|      * @param executorService the executor running the split tasks | ||||
|      */ | ||||
|     public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo, | ||||
|                     List<Cluster> clusters, ExecutorService executorService) { | ||||
|         final Random random = new Random(); | ||||
|         List<Runnable> tasks = new ArrayList<>(); | ||||
|         for (final Cluster cluster : clusters) { | ||||
|             tasks.add(new Runnable() { | ||||
|                 public void run() { | ||||
|                     try { | ||||
|                         Point point = cluster.getPoints().remove(random.nextInt(cluster.getPoints().size())); | ||||
|                         clusterSet.addNewClusterWithCenter(point); | ||||
|                     } catch (Throwable t) { | ||||
|                         log.warn("Error Splitting clusters (2)", t); | ||||
|                     } | ||||
|                 } | ||||
|             }); | ||||
|         } | ||||
| 
 | ||||
|         MultiThreadUtils.parallelTasks(tasks, executorService); | ||||
|     } | ||||
| 
 | ||||
|     public static ReduceOp createDistanceFunctionOp(Distance distanceFunction, INDArray x, INDArray y, int...dimensions){ | ||||
|         val op = createDistanceFunctionOp(distanceFunction, x, y); | ||||
|         op.setDimensions(dimensions); | ||||
|         return op; | ||||
|     } | ||||
| 
 | ||||
|     public static ReduceOp createDistanceFunctionOp(Distance distanceFunction, INDArray x, INDArray y){ | ||||
|         switch (distanceFunction){ | ||||
|             case COSINE_DISTANCE: | ||||
|                 return new CosineDistance(x,y); | ||||
|             case COSINE_SIMILARITY: | ||||
|                 return new CosineSimilarity(x,y); | ||||
|             case DOT: | ||||
|                 return new Dot(x,y); | ||||
|             case EUCLIDEAN: | ||||
|                 return new EuclideanDistance(x,y); | ||||
|             case JACCARD: | ||||
|                 return new JaccardDistance(x,y); | ||||
|             case MANHATTAN: | ||||
|                 return new ManhattanDistance(x,y); | ||||
|             default: | ||||
|                 throw new IllegalStateException("Unknown distance function: " + distanceFunction); | ||||
|         } | ||||
|     } | ||||
| } | ||||
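| // A minimal usage sketch (not part of the removed sources): computing the Euclidean | ||||
| // distance between two vectors via the factory above. | ||||
| // | ||||
| //     INDArray a = Nd4j.create(new double[] {0, 0, 0}); | ||||
| //     INDArray b = Nd4j.create(new double[] {3, 4, 0}); | ||||
| //     double d = Nd4j.getExecutioner() | ||||
| //                     .execAndReturn(ClusterUtils.createDistanceFunctionOp(Distance.EUCLIDEAN, a, b)) | ||||
| //                     .getFinalResult().doubleValue(); // 5.0 | ||||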
| @ -1,107 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.cluster; | ||||
| 
 | ||||
| import lombok.AccessLevel; | ||||
| import lombok.Data; | ||||
| import lombok.NoArgsConstructor; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| import java.util.UUID; | ||||
| 
 | ||||
| /** | ||||
|  * A single data point, backed by an INDArray, with an id and an optional label. | ||||
|  */ | ||||
| @Data | ||||
| @NoArgsConstructor(access = AccessLevel.PROTECTED) | ||||
| public class Point implements Serializable { | ||||
| 
 | ||||
|     private static final long serialVersionUID = -6658028541426027226L; | ||||
| 
 | ||||
|     private String id = UUID.randomUUID().toString(); | ||||
|     private String label; | ||||
|     private INDArray array; | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Create a point with a randomly generated id from the given array. | ||||
|      * | ||||
|      * @param array the underlying data | ||||
|      */ | ||||
|     public Point(INDArray array) { | ||||
|         super(); | ||||
|         this.array = array; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Create a point with the given id from the given array. | ||||
|      * | ||||
|      * @param id    the id of the point | ||||
|      * @param array the underlying data | ||||
|      */ | ||||
|     public Point(String id, INDArray array) { | ||||
|         super(); | ||||
|         this.id = id; | ||||
|         this.array = array; | ||||
|     } | ||||
| 
 | ||||
|     public Point(String id, String label, double[] data) { | ||||
|         this(id, label, Nd4j.create(data)); | ||||
|     } | ||||
| 
 | ||||
|     public Point(String id, String label, INDArray array) { | ||||
|         super(); | ||||
|         this.id = id; | ||||
|         this.label = label; | ||||
|         this.array = array; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Turn each row of the given matrix into a point. | ||||
|      * | ||||
|      * @param matrix the input matrix, one point per row | ||||
|      * @return one point per row of the matrix | ||||
|      */ | ||||
|     public static List<Point> toPoints(INDArray matrix) { | ||||
|         List<Point> arr = new ArrayList<>(matrix.rows()); | ||||
|         for (int i = 0; i < matrix.rows(); i++) { | ||||
|             arr.add(new Point(matrix.getRow(i))); | ||||
|         } | ||||
| 
 | ||||
|         return arr; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Wrap each of the given vectors in a point. | ||||
|      * | ||||
|      * @param vectors the input vectors | ||||
|      * @return one point per input vector | ||||
|      */ | ||||
|     public static List<Point> toPoints(List<INDArray> vectors) { | ||||
|         List<Point> points = new ArrayList<>(); | ||||
|         for (INDArray vector : vectors) | ||||
|             points.add(new Point(vector)); | ||||
|         return points; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
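| // A minimal usage sketch (not part of the removed sources): turning a 3x2 matrix into | ||||
| // three points, one per row. | ||||
| // | ||||
| //     INDArray data = Nd4j.create(new double[][] {{1, 2}, {3, 4}, {5, 6}}); | ||||
| //     List<Point> points = Point.toPoints(data); // points.size() == 3 | ||||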
| @ -1,40 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.cluster; | ||||
| 
 | ||||
| import lombok.AccessLevel; | ||||
| import lombok.AllArgsConstructor; | ||||
| import lombok.Data; | ||||
| import lombok.NoArgsConstructor; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| 
 | ||||
| @Data | ||||
| @NoArgsConstructor(access = AccessLevel.PROTECTED) | ||||
| @AllArgsConstructor | ||||
| public class PointClassification implements Serializable { | ||||
| 
 | ||||
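|     // The cluster the point was assigned to, the point's distance to that cluster's center, | ||||
|     // and whether this assignment differs from the point's previous cluster. | ||||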
|     private Cluster cluster; | ||||
|     private double distanceFromCenter; | ||||
|     private boolean newLocation; | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,37 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.condition; | ||||
| 
 | ||||
| import org.deeplearning4j.clustering.iteration.IterationHistory; | ||||
| 
 | ||||
| /** | ||||
|  * A stopping condition for an iterative clustering algorithm. | ||||
|  */ | ||||
| public interface ClusteringAlgorithmCondition { | ||||
| 
 | ||||
|     /** | ||||
|      * @param iterationHistory the history of iterations so far | ||||
|      * @return true if the condition is satisfied and the algorithm may stop | ||||
|      */ | ||||
|     boolean isSatisfied(IterationHistory iterationHistory); | ||||
| 
 | ||||
| } | ||||
| @ -1,69 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.condition; | ||||
| 
 | ||||
| import lombok.AccessLevel; | ||||
| import lombok.AllArgsConstructor; | ||||
| import lombok.NoArgsConstructor; | ||||
| import org.deeplearning4j.clustering.iteration.IterationHistory; | ||||
| import org.nd4j.linalg.indexing.conditions.Condition; | ||||
| import org.nd4j.linalg.indexing.conditions.LessThan; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| 
 | ||||
| @NoArgsConstructor(access = AccessLevel.PROTECTED) | ||||
| @AllArgsConstructor(access = AccessLevel.PROTECTED) | ||||
| public class ConvergenceCondition implements ClusteringAlgorithmCondition, Serializable { | ||||
| 
 | ||||
|     private Condition convergenceCondition; | ||||
|     private double pointsDistributionChangeRate; | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Create a condition that is satisfied once the fraction of points changing cluster | ||||
|      * in an iteration drops below the given rate. | ||||
|      * | ||||
|      * @param pointsDistributionChangeRate the threshold on the fraction of points changing cluster | ||||
|      * @return the convergence condition | ||||
|      */ | ||||
|     public static ConvergenceCondition distributionVariationRateLessThan(double pointsDistributionChangeRate) { | ||||
|         Condition condition = new LessThan(pointsDistributionChangeRate); | ||||
|         return new ConvergenceCondition(condition, pointsDistributionChangeRate); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * @param iterationHistory the history of iterations so far | ||||
|      * @return true if the most recent iteration's change rate satisfies the condition | ||||
|      */ | ||||
|     public boolean isSatisfied(IterationHistory iterationHistory) { | ||||
|         int iterationCount = iterationHistory.getIterationCount(); | ||||
|         if (iterationCount <= 1) | ||||
|             return false; | ||||
| 
 | ||||
|         double variation = iterationHistory.getMostRecentClusterSetInfo().getPointLocationChange().get(); | ||||
|         variation /= iterationHistory.getMostRecentClusterSetInfo().getPointsCount(); | ||||
| 
 | ||||
|         return convergenceCondition.apply(variation); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
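| // A minimal usage sketch (not part of the removed sources): stop once fewer than 1% of | ||||
| // points change cluster in an iteration. The iterationHistory variable is illustrative. | ||||
| // | ||||
| //     ClusteringAlgorithmCondition stop = ConvergenceCondition.distributionVariationRateLessThan(0.01); | ||||
| //     boolean done = stop.isSatisfied(iterationHistory); | ||||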
| @ -1,61 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.condition; | ||||
| 
 | ||||
| import lombok.AccessLevel; | ||||
| import lombok.NoArgsConstructor; | ||||
| import org.deeplearning4j.clustering.iteration.IterationHistory; | ||||
| import org.nd4j.linalg.indexing.conditions.Condition; | ||||
| import org.nd4j.linalg.indexing.conditions.GreaterThanOrEqual; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| 
 | ||||
| /** | ||||
|  * A condition satisfied once a fixed number of iterations has been reached. | ||||
|  */ | ||||
| @NoArgsConstructor(access = AccessLevel.PROTECTED) | ||||
| public class FixedIterationCountCondition implements ClusteringAlgorithmCondition, Serializable { | ||||
| 
 | ||||
|     private Condition iterationCountCondition; | ||||
| 
 | ||||
|     protected FixedIterationCountCondition(int iterationCount) { | ||||
|         iterationCountCondition = new GreaterThanOrEqual(iterationCount); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Create a condition that is satisfied once at least {@code iterationCount} iterations have run. | ||||
|      * | ||||
|      * @param iterationCount the iteration count threshold | ||||
|      * @return the fixed-iteration-count condition | ||||
|      */ | ||||
|     public static FixedIterationCountCondition iterationCountGreaterThan(int iterationCount) { | ||||
|         return new FixedIterationCountCondition(iterationCount); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param iterationHistory the history of iterations so far (may be null) | ||||
|      * @return true if the number of completed iterations meets the threshold | ||||
|      */ | ||||
|     public boolean isSatisfied(IterationHistory iterationHistory) { | ||||
|         return iterationCountCondition.apply(iterationHistory == null ? 0 : iterationHistory.getIterationCount()); | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,82 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.condition; | ||||
| 
 | ||||
| import lombok.AccessLevel; | ||||
| import lombok.AllArgsConstructor; | ||||
| import lombok.NoArgsConstructor; | ||||
| import org.deeplearning4j.clustering.iteration.IterationHistory; | ||||
| import org.nd4j.linalg.indexing.conditions.Condition; | ||||
| import org.nd4j.linalg.indexing.conditions.LessThan; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| 
 | ||||
| /** | ||||
|  * A condition satisfied once the within-cluster distance variance has stabilised over a given period. | ||||
|  */ | ||||
| @NoArgsConstructor(access = AccessLevel.PROTECTED) | ||||
| @AllArgsConstructor | ||||
| public class VarianceVariationCondition implements ClusteringAlgorithmCondition, Serializable { | ||||
| 
 | ||||
|     private Condition varianceVariationCondition; | ||||
|     private int period; | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Create a condition that is satisfied once the relative variation of the point-distance | ||||
|      * variance stays below {@code varianceVariation} for {@code period} consecutive iterations. | ||||
|      * | ||||
|      * @param varianceVariation the threshold on the relative variance variation | ||||
|      * @param period            the number of consecutive iterations the threshold must hold | ||||
|      * @return the variance-variation condition | ||||
|      */ | ||||
|     public static VarianceVariationCondition varianceVariationLessThan(double varianceVariation, int period) { | ||||
|         Condition condition = new LessThan(varianceVariation); | ||||
|         return new VarianceVariationCondition(condition, period); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * @param iterationHistory the history of iterations so far | ||||
|      * @return true if the variance variation stayed below the threshold for the whole period | ||||
|      */ | ||||
|     public boolean isSatisfied(IterationHistory iterationHistory) { | ||||
|         if (iterationHistory.getIterationCount() <= period) | ||||
|             return false; | ||||
| 
 | ||||
|         // Iterations are indexed 0..count-1, so the most recent iteration is at count - 1. | ||||
|         for (int i = 0, j = iterationHistory.getIterationCount(); i < period; i++) { | ||||
|             double variation = iterationHistory.getIterationInfo(j - i - 1).getClusterSetInfo() | ||||
|                             .getPointDistanceFromClusterVariance(); | ||||
|             variation -= iterationHistory.getIterationInfo(j - i - 2).getClusterSetInfo() | ||||
|                             .getPointDistanceFromClusterVariance(); | ||||
|             variation /= iterationHistory.getIterationInfo(j - i - 2).getClusterSetInfo() | ||||
|                             .getPointDistanceFromClusterVariance(); | ||||
| 
 | ||||
|             if (!varianceVariationCondition.apply(variation)) | ||||
|                 return false; | ||||
|         } | ||||
| 
 | ||||
|         return true; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,114 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.info; | ||||
| 
 | ||||
| import lombok.Data; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| import java.util.*; | ||||
| import java.util.concurrent.ConcurrentHashMap; | ||||
| 
 | ||||
| /** | ||||
|  * Distance statistics for a single cluster: per-point distances from the center plus derived aggregates. | ||||
|  */ | ||||
| @Data | ||||
| public class ClusterInfo implements Serializable { | ||||
| 
 | ||||
|     private double averagePointDistanceFromCenter; | ||||
|     private double maxPointDistanceFromCenter; | ||||
|     private double pointDistanceFromCenterVariance; | ||||
|     private double totalPointDistanceFromCenter; | ||||
|     private boolean inverse; | ||||
|     private Map<String, Double> pointDistancesFromCenter = new ConcurrentHashMap<>(); | ||||
| 
 | ||||
|     public ClusterInfo(boolean inverse) { | ||||
|         this(false, inverse); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param threadSafe whether the point-distance map must be safe for concurrent access | ||||
|      * @param inverse    whether the distance function is an inverse (similarity) measure | ||||
|      */ | ||||
|     public ClusterInfo(boolean threadSafe, boolean inverse) { | ||||
|         super(); | ||||
|         this.inverse = inverse; | ||||
|         if (threadSafe) { | ||||
|             pointDistancesFromCenter = Collections.synchronizedMap(pointDistancesFromCenter); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @return the point distances from the center, sorted in ascending order (duplicates preserved) | ||||
|      */ | ||||
|     public Set<Map.Entry<String, Double>> getSortedPointDistancesFromCenter() { | ||||
|         SortedSet<Map.Entry<String, Double>> sortedEntries = new TreeSet<>(new Comparator<Map.Entry<String, Double>>() { | ||||
|             @Override | ||||
|             public int compare(Map.Entry<String, Double> e1, Map.Entry<String, Double> e2) { | ||||
|                 int res = e1.getValue().compareTo(e2.getValue()); | ||||
|                 return res != 0 ? res : 1; | ||||
|             } | ||||
|         }); | ||||
|         sortedEntries.addAll(pointDistancesFromCenter.entrySet()); | ||||
|         return sortedEntries; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @return the point distances from the center, sorted in descending order (duplicates preserved) | ||||
|      */ | ||||
|     public Set<Map.Entry<String, Double>> getReverseSortedPointDistancesFromCenter() { | ||||
|         SortedSet<Map.Entry<String, Double>> sortedEntries = new TreeSet<>(new Comparator<Map.Entry<String, Double>>() { | ||||
|             @Override | ||||
|             public int compare(Map.Entry<String, Double> e1, Map.Entry<String, Double> e2) { | ||||
|                 int res = e1.getValue().compareTo(e2.getValue()); | ||||
|                 return -(res != 0 ? res : 1); | ||||
|             } | ||||
|         }); | ||||
|         sortedEntries.addAll(pointDistancesFromCenter.entrySet()); | ||||
|         return sortedEntries; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the ids of the points whose distance from the center exceeds the given threshold. | ||||
|      * | ||||
|      * @param maxDistance the distance threshold | ||||
|      * @return the ids of the points farther from the center than {@code maxDistance}, | ||||
|      *         farthest first | ||||
|      */ | ||||
|     public List<String> getPointsFartherFromCenterThan(double maxDistance) { | ||||
|         Set<Map.Entry<String, Double>> sorted = getReverseSortedPointDistancesFromCenter(); | ||||
|         List<String> ids = new ArrayList<>(); | ||||
|         // For inverse (similarity) measures the check is mirrored around -maxDistance, | ||||
|         // as in the original comparison. | ||||
|         for (Map.Entry<String, Double> entry : sorted) { | ||||
|             boolean farther = inverse ? entry.getValue() < -maxDistance | ||||
|                             : entry.getValue() > maxDistance; | ||||
|             if (farther) | ||||
|                 ids.add(entry.getKey()); | ||||
|         } | ||||
|         return ids; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
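| // A minimal usage sketch (not part of the removed sources): listing points more than | ||||
| // 2.5 units from a cluster's center. The clusterInfo variable is illustrative. | ||||
| // | ||||
| //     List<String> outlierIds = clusterInfo.getPointsFartherFromCenterThan(2.5); | ||||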
| @ -1,142 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.info; | ||||
| 
 | ||||
| import org.nd4j.shade.guava.collect.HashBasedTable; | ||||
| import org.nd4j.shade.guava.collect.Table; | ||||
| import org.deeplearning4j.clustering.cluster.Cluster; | ||||
| import org.deeplearning4j.clustering.cluster.ClusterSet; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| import java.util.Collections; | ||||
| import java.util.HashMap; | ||||
| import java.util.List; | ||||
| import java.util.Map; | ||||
| import java.util.concurrent.atomic.AtomicInteger; | ||||
| 
 | ||||
| public class ClusterSetInfo implements Serializable { | ||||
| 
 | ||||
|     private Map<String, ClusterInfo> clustersInfos = new HashMap<>(); | ||||
|     private Table<String, String, Double> distancesBetweenClustersCenters = HashBasedTable.create(); | ||||
|     private AtomicInteger pointLocationChange; | ||||
|     private boolean threadSafe; | ||||
|     private boolean inverse; | ||||
| 
 | ||||
|     public ClusterSetInfo(boolean inverse) { | ||||
|         this(inverse, false); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param inverse    whether the distance function is an inverse (similarity) measure | ||||
|      * @param threadSafe whether the per-cluster info map must be safe for concurrent access | ||||
|      */ | ||||
|     public ClusterSetInfo(boolean inverse, boolean threadSafe) { | ||||
|         this.pointLocationChange = new AtomicInteger(0); | ||||
|         this.threadSafe = threadSafe; | ||||
|         this.inverse = inverse; | ||||
|         if (threadSafe) { | ||||
|             clustersInfos = Collections.synchronizedMap(clustersInfos); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Create a ClusterSetInfo with one (empty) ClusterInfo entry per cluster in the given set. | ||||
|      * | ||||
|      * @param clusterSet the cluster set to mirror | ||||
|      * @param threadSafe whether the created infos must be safe for concurrent access | ||||
|      * @return the initialized cluster set info | ||||
|      */ | ||||
|     public static ClusterSetInfo initialize(ClusterSet clusterSet, boolean threadSafe) { | ||||
|         ClusterSetInfo info = new ClusterSetInfo(clusterSet.isInverse(), threadSafe); | ||||
|         for (int i = 0, j = clusterSet.getClusterCount(); i < j; i++) | ||||
|             info.addClusterInfo(clusterSet.getClusters().get(i).getId()); | ||||
|         return info; | ||||
|     } | ||||
| 
 | ||||
|     public void removeClusterInfos(List<Cluster> clusters) { | ||||
|         for (Cluster cluster : clusters) { | ||||
|             clustersInfos.remove(cluster.getId()); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     public ClusterInfo addClusterInfo(String clusterId) { | ||||
|         // Pass both flags through: the single-argument ClusterInfo constructor interprets its | ||||
|         // parameter as "inverse", so passing threadSafe alone would silently misconfigure it. | ||||
|         ClusterInfo clusterInfo = new ClusterInfo(this.threadSafe, this.inverse); | ||||
|         clustersInfos.put(clusterId, clusterInfo); | ||||
|         return clusterInfo; | ||||
|     } | ||||
| 
 | ||||
|     public ClusterInfo getClusterInfo(String clusterId) { | ||||
|         return clustersInfos.get(clusterId); | ||||
|     } | ||||
| 
 | ||||
|     public double getAveragePointDistanceFromClusterCenter() { | ||||
|         if (clustersInfos == null || clustersInfos.isEmpty()) | ||||
|             return 0; | ||||
| 
 | ||||
|         double average = 0; | ||||
|         for (ClusterInfo info : clustersInfos.values()) | ||||
|             average += info.getAveragePointDistanceFromCenter(); | ||||
|         return average / clustersInfos.size(); | ||||
|     } | ||||
| 
 | ||||
|     public double getPointDistanceFromClusterVariance() { | ||||
|         if (clustersInfos == null || clustersInfos.isEmpty()) | ||||
|             return 0; | ||||
| 
 | ||||
|         double average = 0; | ||||
|         for (ClusterInfo info : clustersInfos.values()) | ||||
|             average += info.getPointDistanceFromCenterVariance(); | ||||
|         return average / clustersInfos.size(); | ||||
|     } | ||||
| 
 | ||||
|     public int getPointsCount() { | ||||
|         int count = 0; | ||||
|         for (ClusterInfo clusterInfo : clustersInfos.values()) | ||||
|             count += clusterInfo.getPointDistancesFromCenter().size(); | ||||
|         return count; | ||||
|     } | ||||
| 
 | ||||
|     public Map<String, ClusterInfo> getClustersInfos() { | ||||
|         return clustersInfos; | ||||
|     } | ||||
| 
 | ||||
|     public void setClustersInfos(Map<String, ClusterInfo> clustersInfos) { | ||||
|         this.clustersInfos = clustersInfos; | ||||
|     } | ||||
| 
 | ||||
|     public Table<String, String, Double> getDistancesBetweenClustersCenters() { | ||||
|         return distancesBetweenClustersCenters; | ||||
|     } | ||||
| 
 | ||||
|     public void setDistancesBetweenClustersCenters(Table<String, String, Double> interClusterDistances) { | ||||
|         this.distancesBetweenClustersCenters = interClusterDistances; | ||||
|     } | ||||
| 
 | ||||
|     public AtomicInteger getPointLocationChange() { | ||||
|         return pointLocationChange; | ||||
|     } | ||||
| 
 | ||||
|     public void setPointLocationChange(AtomicInteger pointLocationChange) { | ||||
|         this.pointLocationChange = pointLocationChange; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
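| // A minimal usage sketch (not part of the removed sources): building the info holder for | ||||
| // an existing cluster set before an iteration. The clusterSet variable is illustrative. | ||||
| // | ||||
| //     ClusterSetInfo info = ClusterSetInfo.initialize(clusterSet, true); | ||||
| //     int changed = info.getPointLocationChange().get(); // 0 until points are classified | ||||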
| @ -1,72 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.iteration; | ||||
| 
 | ||||
| import lombok.Getter; | ||||
| import lombok.Setter; | ||||
| import org.deeplearning4j.clustering.info.ClusterSetInfo; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| import java.util.HashMap; | ||||
| import java.util.Map; | ||||
| 
 | ||||
| public class IterationHistory implements Serializable { | ||||
|     @Getter | ||||
|     @Setter | ||||
|     private Map<Integer, IterationInfo> iterationsInfos = new HashMap<>(); | ||||
| 
 | ||||
|     /** | ||||
|      * @return the cluster set info of the most recent iteration, or null if none has been recorded | ||||
|      */ | ||||
|     public ClusterSetInfo getMostRecentClusterSetInfo() { | ||||
|         IterationInfo iterationInfo = getMostRecentIterationInfo(); | ||||
|         return iterationInfo == null ? null : iterationInfo.getClusterSetInfo(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @return the info of the most recent iteration, or null if the history is empty | ||||
|      */ | ||||
|     public IterationInfo getMostRecentIterationInfo() { | ||||
|         return getIterationInfo(getIterationCount() - 1); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @return the number of iterations recorded so far | ||||
|      */ | ||||
|     public int getIterationCount() { | ||||
|         return getIterationsInfos().size(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param iterationIdx the zero-based index of the iteration | ||||
|      * @return the info recorded for that iteration, or null if absent | ||||
|      */ | ||||
|     public IterationInfo getIterationInfo(int iterationIdx) { | ||||
|         return getIterationsInfos().get(iterationIdx); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
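| // A minimal usage sketch: the map is keyed by iteration index, so after recording | ||||
| // iterations 0..n-1, getMostRecentIterationInfo() returns the entry at index n-1. | ||||
| class IterationHistoryExample { | ||||
|     static IterationInfo recordAndFetch(int iterations) { | ||||
|         IterationHistory history = new IterationHistory(); | ||||
|         for (int i = 0; i < iterations; i++) | ||||
|             history.getIterationsInfos().put(i, new IterationInfo(i)); | ||||
|         return history.getMostRecentIterationInfo(); // the last iteration recorded | ||||
|     } | ||||
| } | ||||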
| @ -1,49 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.iteration; | ||||
| 
 | ||||
| import lombok.AccessLevel; | ||||
| import lombok.Data; | ||||
| import lombok.NoArgsConstructor; | ||||
| import org.deeplearning4j.clustering.info.ClusterSetInfo; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| 
 | ||||
| @Data | ||||
| @NoArgsConstructor(access = AccessLevel.PROTECTED) | ||||
| public class IterationInfo implements Serializable { | ||||
| 
 | ||||
|     private int index; | ||||
|     private ClusterSetInfo clusterSetInfo; | ||||
|     private boolean strategyApplied; | ||||
| 
 | ||||
|     public IterationInfo(int index) { | ||||
|         super(); | ||||
|         this.index = index; | ||||
|     } | ||||
| 
 | ||||
|     public IterationInfo(int index, ClusterSetInfo clusterSetInfo) { | ||||
|         super(); | ||||
|         this.index = index; | ||||
|         this.clusterSetInfo = clusterSetInfo; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,142 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.kdtree; | ||||
| 
 | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.ops.custom.KnnMinDistance; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.common.primitives.Pair; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| 
 | ||||
| public class HyperRect implements Serializable { | ||||
| 
 | ||||
|     //private List<Interval> points; | ||||
|     private float[] lowerEnds; | ||||
|     private float[] higherEnds; | ||||
|     private INDArray lowerEndsIND; | ||||
|     private INDArray higherEndsIND; | ||||
| 
 | ||||
|     public HyperRect(float[] lowerEndsIn, float[] higherEndsIn) { | ||||
|         this.lowerEnds = new float[lowerEndsIn.length]; | ||||
|         this.higherEnds = new float[lowerEndsIn.length]; | ||||
|         System.arraycopy(lowerEndsIn, 0 , this.lowerEnds, 0, lowerEndsIn.length); | ||||
|         System.arraycopy(higherEndsIn, 0 , this.higherEnds, 0, higherEndsIn.length); | ||||
|         lowerEndsIND = Nd4j.createFromArray(lowerEnds); | ||||
|         higherEndsIND = Nd4j.createFromArray(higherEnds); | ||||
|     } | ||||
| 
 | ||||
|     public HyperRect(float[] point) { | ||||
|         this(point, point); | ||||
|     } | ||||
| 
 | ||||
|     public HyperRect(Pair<float[], float[]> ends) { | ||||
|         this(ends.getFirst(), ends.getSecond()); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public void enlargeTo(INDArray point) { | ||||
|         float[] pointAsArray = point.toFloatVector(); | ||||
|         for (int i = 0; i < lowerEnds.length; i++) { | ||||
|             float p = pointAsArray[i]; | ||||
|             if (lowerEnds[i] > p) | ||||
|                 lowerEnds[i] = p; | ||||
|             else if (higherEnds[i] < p) | ||||
|                 higherEnds[i] = p; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     public static Pair<float[],float[]> point(INDArray vector) { | ||||
|         Pair<float[],float[]> ret = new Pair<>(); | ||||
|         float[] curr = new float[(int)vector.length()]; | ||||
|         for (int i = 0; i < vector.length(); i++) { | ||||
|             curr[i] = vector.getFloat(i); | ||||
|         } | ||||
|         ret.setFirst(curr); | ||||
|         ret.setSecond(curr); | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /*public List<Boolean> contains(INDArray hPoint) { | ||||
|         List<Boolean> ret = new ArrayList<>(); | ||||
|         for (int i = 0; i < hPoint.length(); i++) { | ||||
|             ret.add(lowerEnds[i] <= hPoint.getDouble(i) && | ||||
|                     higherEnds[i] >= hPoint.getDouble(i)); | ||||
|         } | ||||
|         return ret; | ||||
|     }*/ | ||||
| 
 | ||||
|     public double minDistance(INDArray hPoint, INDArray output) { | ||||
|         Nd4j.exec(new KnnMinDistance(hPoint, lowerEndsIND, higherEndsIND, output)); | ||||
|         return output.getFloat(0); | ||||
| 
 | ||||
|         /*double ret = 0.0; | ||||
|         double[] pointAsArray = hPoint.toDoubleVector(); | ||||
|         for (int i = 0; i < pointAsArray.length; i++) { | ||||
|            double p = pointAsArray[i]; | ||||
|            if (!(lowerEnds[i] <= p || higherEnds[i] <= p)) { | ||||
|               if (p < lowerEnds[i]) | ||||
|                  ret += Math.pow((p - lowerEnds[i]), 2); | ||||
|               else | ||||
|                  ret += Math.pow((p - higherEnds[i]), 2); | ||||
|            } | ||||
|         } | ||||
|         ret = Math.pow(ret, 0.5); | ||||
|         return ret;*/ | ||||
|     } | ||||
| 
 | ||||
|     public HyperRect getUpper(INDArray hPoint, int desc) { | ||||
|         //Interval interval = points.get(desc); | ||||
|         float higher = higherEnds[desc]; | ||||
|         float d = hPoint.getFloat(desc); | ||||
|         if (higher < d) | ||||
|             return null; | ||||
|         HyperRect ret = new HyperRect(lowerEnds,higherEnds); | ||||
|         if (ret.lowerEnds[desc] < d) | ||||
|             ret.lowerEnds[desc] = d; | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
|     public HyperRect getLower(INDArray hPoint, int desc) { | ||||
|         //Interval interval = points.get(desc); | ||||
|         float lower = lowerEnds[desc]; | ||||
|         float d = hPoint.getFloat(desc); | ||||
|         if (lower > d) | ||||
|             return null; | ||||
|         HyperRect ret = new HyperRect(lowerEnds,higherEnds); | ||||
|         //Interval i2 = ret.points.get(desc); | ||||
|         if (ret.higherEnds[desc] > d) | ||||
|             ret.higherEnds[desc] = d; | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public String toString() { | ||||
|         String retVal = ""; | ||||
|         retVal +=  "["; | ||||
|         for (int i = 0; i < lowerEnds.length; ++i) { | ||||
|             retVal +=  "("  + lowerEnds[i] + " - " + higherEnds[i] + ") "; | ||||
|         } | ||||
|         retVal +=  "]"; | ||||
|         return retVal; | ||||
|     } | ||||
| } | ||||
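| // A minimal usage sketch of how KDTree consumes this class: getLower/getUpper split | ||||
| // the region at a point's coordinate along one dimension. | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| class HyperRectExample { | ||||
|     public static void main(String[] args) { | ||||
|         HyperRect rect = new HyperRect(new float[] {0f, 0f}, new float[] {1f, 1f}); | ||||
|         rect.enlargeTo(Nd4j.createFromArray(new float[] {2f, 0.5f})); // x range grows to [0, 2] | ||||
|         // split along dimension 0 at x = 1: lower half is x in [0, 1], upper is x in [1, 2] | ||||
|         HyperRect lower = rect.getLower(Nd4j.createFromArray(new float[] {1f, 0.5f}), 0); | ||||
|         HyperRect upper = rect.getUpper(Nd4j.createFromArray(new float[] {1f, 0.5f}), 0); | ||||
|         System.out.println(lower + " / " + upper); | ||||
|     } | ||||
| } | ||||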
| @ -1,370 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.kdtree; | ||||
| 
 | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.ops.impl.reduce.bool.Any; | ||||
| import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.common.primitives.Pair; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| import java.util.ArrayList; | ||||
| import java.util.Collections; | ||||
| import java.util.Comparator; | ||||
| import java.util.List; | ||||
| 
 | ||||
| public class KDTree implements Serializable { | ||||
| 
 | ||||
|     private KDNode root; | ||||
|     private int dims = 100; | ||||
|     public final static int GREATER = 1; | ||||
|     public final static int LESS = 0; | ||||
|     private int size = 0; | ||||
|     private HyperRect rect; | ||||
| 
 | ||||
|     public KDTree(int dims) { | ||||
|         this.dims = dims; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Insert a point into the tree | ||||
|      * @param point the point to insert | ||||
|      */ | ||||
|     public void insert(INDArray point) { | ||||
|         if (!point.isVector() || point.length() != dims) | ||||
|             throw new IllegalArgumentException("Point must be a vector of length " + dims); | ||||
| 
 | ||||
|         if (root == null) { | ||||
|             root = new KDNode(point); | ||||
|             rect = new HyperRect(/*HyperRect.point(point)*/ point.toFloatVector()); | ||||
|         } else { | ||||
|             int disc = 0; | ||||
|             KDNode node = root; | ||||
|             KDNode insert = new KDNode(point); | ||||
|             int successor; | ||||
|             while (true) { | ||||
|                 //exactly equal | ||||
|                 INDArray pt = node.getPoint(); | ||||
|                 INDArray countEq = Nd4j.getExecutioner().execAndReturn(new Any(pt.neq(point))).z(); | ||||
|                 if (countEq.getInt(0) == 0) { | ||||
|                     return; | ||||
|                 } else { | ||||
|                     successor = successor(node, point, disc); | ||||
|                     KDNode child; | ||||
|                     if (successor < 1) | ||||
|                         child = node.getLeft(); | ||||
|                     else | ||||
|                         child = node.getRight(); | ||||
|                     if (child == null) | ||||
|                         break; | ||||
|                     disc = (disc + 1) % dims; | ||||
|                     node = child; | ||||
|                 } | ||||
|             } | ||||
| 
 | ||||
|             if (successor < 1) | ||||
|                 node.setLeft(insert); | ||||
| 
 | ||||
|             else | ||||
|                 node.setRight(insert); | ||||
| 
 | ||||
|             rect.enlargeTo(point); | ||||
|             insert.setParent(node); | ||||
|         } | ||||
|         size++; | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public INDArray delete(INDArray point) { | ||||
|         KDNode node = root; | ||||
|         int _disc = 0; | ||||
|         while (node != null) { | ||||
|             if (node.point == point) | ||||
|                 break; | ||||
|             int successor = successor(node, point, _disc); | ||||
|             if (successor < 1) | ||||
|                 node = node.getLeft(); | ||||
|             else | ||||
|                 node = node.getRight(); | ||||
|             _disc = (_disc + 1) % dims; | ||||
|         } | ||||
| 
 | ||||
|         if (node != null) { | ||||
|             if (node == root) { | ||||
|                 root = delete(root, _disc); | ||||
|             } else | ||||
|                 node = delete(node, _disc); | ||||
|             size--; | ||||
|             if (size == 1) { | ||||
|                 rect = new HyperRect(HyperRect.point(point)); | ||||
|             } else if (size == 0) | ||||
|                 rect = null; | ||||
| 
 | ||||
|         } | ||||
|         // the point was not found if node is null; avoid a NullPointerException here | ||||
|         return node == null ? null : node.getPoint(); | ||||
|     } | ||||
| 
 | ||||
|     // Share this data for recursive calls of "knn" | ||||
|     private float currentDistance; | ||||
|     private INDArray currentPoint; | ||||
|     private INDArray minDistance = Nd4j.scalar(0.f); | ||||
| 
 | ||||
| 
 | ||||
|     public List<Pair<Float, INDArray>> knn(INDArray point, float distance) { | ||||
|         List<Pair<Float, INDArray>> best = new ArrayList<>(); | ||||
|         currentDistance = distance; | ||||
|         currentPoint = point; | ||||
|         knn(root, rect, best, 0); | ||||
|         Collections.sort(best, new Comparator<Pair<Float, INDArray>>() { | ||||
|             @Override | ||||
|             public int compare(Pair<Float, INDArray> o1, Pair<Float, INDArray> o2) { | ||||
|                 return Float.compare(o1.getKey(), o2.getKey()); | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         return best; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     private void knn(KDNode node, HyperRect rect, List<Pair<Float, INDArray>> best, int _disc) { | ||||
|         if (node == null || rect == null || rect.minDistance(currentPoint, minDistance) > currentDistance) | ||||
|             return; | ||||
|         int _discNext = (_disc + 1) % dims; | ||||
|         float distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(currentPoint,node.point, minDistance)).getFinalResult() | ||||
|                 .floatValue(); | ||||
| 
 | ||||
|         if (distance <= currentDistance) { | ||||
|             best.add(Pair.of(distance, node.getPoint())); | ||||
|         } | ||||
| 
 | ||||
|         HyperRect lower = rect.getLower(node.point, _disc); | ||||
|         HyperRect upper = rect.getUpper(node.point, _disc); | ||||
|         knn(node.getLeft(), lower, best, _discNext); | ||||
|         knn(node.getRight(), upper, best, _discNext); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Query for nearest neighbor. Returns the distance and point | ||||
|      * @param point the point to query for | ||||
|      * @return the distance to, and the coordinates of, the nearest neighbor | ||||
|      */ | ||||
|     public Pair<Double, INDArray> nn(INDArray point) { | ||||
|         return nn(root, point, rect, Double.POSITIVE_INFINITY, null, 0); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     private Pair<Double, INDArray> nn(KDNode node, INDArray point, HyperRect rect, double dist, INDArray best, | ||||
|                     int _disc) { | ||||
|         if (node == null || rect.minDistance(point, minDistance) > dist) | ||||
|             return Pair.of(Double.POSITIVE_INFINITY, null); | ||||
| 
 | ||||
|         int _discNext = (_disc + 1) % dims; | ||||
|         // the distance must be measured against the node's point, not the origin | ||||
|         double dist2 = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(point, node.getPoint())).getFinalResult().doubleValue(); | ||||
|         if (dist2 < dist) { | ||||
|             best = node.getPoint(); | ||||
|             dist = dist2; | ||||
|         } | ||||
| 
 | ||||
|         HyperRect lower = rect.getLower(node.point, _disc); | ||||
|         HyperRect upper = rect.getUpper(node.point, _disc); | ||||
| 
 | ||||
|         if (point.getDouble(_disc) < node.point.getDouble(_disc)) { | ||||
|             Pair<Double, INDArray> left = nn(node.getLeft(), point, lower, dist, best, _discNext); | ||||
|             Pair<Double, INDArray> right = nn(node.getRight(), point, upper, dist, best, _discNext); | ||||
|             if (left.getKey() < dist) | ||||
|                 return left; | ||||
|             else if (right.getKey() < dist) | ||||
|                 return right; | ||||
| 
 | ||||
|         } else { | ||||
|             Pair<Double, INDArray> left = nn(node.getRight(), point, upper, dist, best, _discNext); | ||||
|             Pair<Double, INDArray> right = nn(node.getLeft(), point, lower, dist, best, _discNext); | ||||
|             if (left.getKey() < dist) | ||||
|                 return left; | ||||
|             else if (right.getKey() < dist) | ||||
|                 return right; | ||||
|         } | ||||
| 
 | ||||
|         return Pair.of(dist, best); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     private KDNode delete(KDNode delete, int _disc) { | ||||
|         if (delete.getLeft() != null && delete.getRight() != null) { | ||||
|             if (delete.getParent() != null) { | ||||
|                 if (delete.getParent().getLeft() == delete) | ||||
|                     delete.getParent().setLeft(null); | ||||
|                 else | ||||
|                     delete.getParent().setRight(null); | ||||
| 
 | ||||
|             } | ||||
|             return null; | ||||
|         } | ||||
| 
 | ||||
|         int disc = _disc; | ||||
|         _disc = (_disc + 1) % dims; | ||||
|         Pair<KDNode, Integer> qd = null; | ||||
|         if (delete.getRight() != null) { | ||||
|             qd = min(delete.getRight(), disc, _disc); | ||||
|         } else if (delete.getLeft() != null) | ||||
|             qd = max(delete.getLeft(), disc, _disc); | ||||
|         if (qd == null) {// is leaf | ||||
|             return null; | ||||
|         } | ||||
|         delete.point = qd.getKey().point; | ||||
|         KDNode qFather = qd.getKey().getParent(); | ||||
|         if (qFather.getLeft() == qd.getKey()) { | ||||
|             qFather.setLeft(delete(qd.getKey(), disc)); | ||||
|         } else if (qFather.getRight() == qd.getKey()) { | ||||
|             qFather.setRight(delete(qd.getKey(), disc)); | ||||
| 
 | ||||
|         } | ||||
| 
 | ||||
|         return delete; | ||||
| 
 | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     private Pair<KDNode, Integer> max(KDNode node, int disc, int _disc) { | ||||
|         int discNext = (_disc + 1) % dims; | ||||
|         if (_disc == disc) { | ||||
|             KDNode child = node.getLeft(); | ||||
|             if (child != null) { | ||||
|                 return max(child, disc, discNext); | ||||
|             } | ||||
|         } else if (node.getLeft() != null || node.getRight() != null) { | ||||
|             Pair<KDNode, Integer> left = null, right = null; | ||||
|             if (node.getLeft() != null) | ||||
|                 left = max(node.getLeft(), disc, discNext); | ||||
|             if (node.getRight() != null) | ||||
|                 right = max(node.getRight(), disc, discNext); | ||||
|             if (left != null && right != null) { | ||||
|                 double pointLeft = left.getKey().getPoint().getDouble(disc); | ||||
|                 double pointRight = right.getKey().getPoint().getDouble(disc); | ||||
|                 if (pointLeft > pointRight) | ||||
|                     return left; | ||||
|                 else | ||||
|                     return right; | ||||
|             } else if (left != null) | ||||
|                 return left; | ||||
|             else | ||||
|                 return right; | ||||
|         } | ||||
| 
 | ||||
|         return Pair.of(node, _disc); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     private Pair<KDNode, Integer> min(KDNode node, int disc, int _disc) { | ||||
|         int discNext = (_disc + 1) % dims; | ||||
|         if (_disc == disc) { | ||||
|             KDNode child = node.getLeft(); | ||||
|             if (child != null) { | ||||
|                 return min(child, disc, discNext); | ||||
|             } | ||||
|         } else if (node.getLeft() != null || node.getRight() != null) { | ||||
|             Pair<KDNode, Integer> left = null, right = null; | ||||
|             if (node.getLeft() != null) | ||||
|                 left = min(node.getLeft(), disc, discNext); | ||||
|             if (node.getRight() != null) | ||||
|                 right = min(node.getRight(), disc, discNext); | ||||
|             if (left != null && right != null) { | ||||
|                 double pointLeft = left.getKey().getPoint().getDouble(disc); | ||||
|                 double pointRight = right.getKey().getPoint().getDouble(disc); | ||||
|                 if (pointLeft < pointRight) | ||||
|                     return left; | ||||
|                 else | ||||
|                     return right; | ||||
|             } else if (left != null) | ||||
|                 return left; | ||||
|             else | ||||
|                 return right; | ||||
|         } | ||||
| 
 | ||||
|         return Pair.of(node, _disc); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * The number of elements in the tree | ||||
|      * @return the number of elements in the tree | ||||
|      */ | ||||
|     public int size() { | ||||
|         return size; | ||||
|     } | ||||
| 
 | ||||
|     private int successor(KDNode node, INDArray point, int disc) { | ||||
|         for (int i = disc; i < dims; i++) { | ||||
|             double pointI = point.getDouble(i); | ||||
|             double nodePointI = node.getPoint().getDouble(i); | ||||
|             if (pointI < nodePointI) | ||||
|                 return LESS; | ||||
|             else if (pointI > nodePointI) | ||||
|                 return GREATER; | ||||
| 
 | ||||
|         } | ||||
| 
 | ||||
|         throw new IllegalStateException("Point is equal!"); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     private static class KDNode { | ||||
|         private INDArray point; | ||||
|         private KDNode left, right, parent; | ||||
| 
 | ||||
|         public KDNode(INDArray point) { | ||||
|             this.point = point; | ||||
|         } | ||||
| 
 | ||||
|         public INDArray getPoint() { | ||||
|             return point; | ||||
|         } | ||||
| 
 | ||||
|         public KDNode getLeft() { | ||||
|             return left; | ||||
|         } | ||||
| 
 | ||||
|         public void setLeft(KDNode left) { | ||||
|             this.left = left; | ||||
|         } | ||||
| 
 | ||||
|         public KDNode getRight() { | ||||
|             return right; | ||||
|         } | ||||
| 
 | ||||
|         public void setRight(KDNode right) { | ||||
|             this.right = right; | ||||
|         } | ||||
| 
 | ||||
|         public KDNode getParent() { | ||||
|             return parent; | ||||
|         } | ||||
| 
 | ||||
|         public void setParent(KDNode parent) { | ||||
|             this.parent = parent; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
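| // A minimal usage sketch: insert 2-D points, then run a nearest-neighbor and a | ||||
| // radius query; dims passed to the constructor must match the vector length. | ||||
| import org.nd4j.common.primitives.Pair; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import java.util.List; | ||||
| class KDTreeExample { | ||||
|     public static void main(String[] args) { | ||||
|         KDTree tree = new KDTree(2); | ||||
|         tree.insert(Nd4j.createFromArray(0f, 0f)); | ||||
|         tree.insert(Nd4j.createFromArray(1f, 1f)); | ||||
|         tree.insert(Nd4j.createFromArray(0f, 1f)); | ||||
|         Pair<Double, INDArray> nearest = tree.nn(Nd4j.createFromArray(0.9f, 0.8f)); | ||||
|         List<Pair<Float, INDArray>> close = tree.knn(Nd4j.createFromArray(0f, 0f), 0.5f); | ||||
|         System.out.println(nearest.getSecond() + "; " + close.size() + " within 0.5"); | ||||
|     } | ||||
| } | ||||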
| @ -1,109 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.kmeans; | ||||
| 
 | ||||
| import org.deeplearning4j.clustering.algorithm.BaseClusteringAlgorithm; | ||||
| import org.deeplearning4j.clustering.algorithm.Distance; | ||||
| import org.deeplearning4j.clustering.strategy.ClusteringStrategy; | ||||
| import org.deeplearning4j.clustering.strategy.FixedClusterCountStrategy; | ||||
| 
 | ||||
| 
 | ||||
| public class KMeansClustering extends BaseClusteringAlgorithm { | ||||
| 
 | ||||
|     private static final long serialVersionUID = 8476951388145944776L; | ||||
|     private static final double VARIATION_TOLERANCE= 1e-4; | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * @param clusteringStrategy the strategy that drives and terminates the iterations | ||||
|      * @param useKMeansPlusPlus whether to use k-means++ center initialization | ||||
|      */ | ||||
|     protected KMeansClustering(ClusteringStrategy clusteringStrategy, boolean useKMeansPlusPlus) { | ||||
|         super(clusteringStrategy, useKMeansPlusPlus); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Setup a kmeans instance | ||||
|      * @param clusterCount the number of clusters | ||||
|      * @param maxIterationCount the max number of iterations | ||||
|      *                          to run kmeans | ||||
|      * @param distanceFunction the distance function to use for grouping | ||||
|      * @return a KMeansClustering instance that stops after maxIterationCount iterations | ||||
|      */ | ||||
|     public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction, | ||||
|                     boolean inverse, boolean useKMeansPlusPlus) { | ||||
|         ClusteringStrategy clusteringStrategy = | ||||
|                         FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse); | ||||
|         clusteringStrategy.endWhenIterationCountEquals(maxIterationCount); | ||||
|         return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param clusterCount the number of clusters | ||||
|      * @param minDistributionVariationRate the variation rate below which iteration stops | ||||
|      * @param distanceFunction the distance function to use for grouping | ||||
|      * @param inverse whether the distance function is inverted (treated as a similarity) | ||||
|      * @param allowEmptyClusters whether clusters are allowed to end up empty | ||||
|      * @param useKMeansPlusPlus whether to use k-means++ center initialization | ||||
|      * @return a KMeansClustering instance that stops once the distribution variation rate drops below the threshold | ||||
|      */ | ||||
|     public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction, | ||||
|                     boolean inverse, boolean allowEmptyClusters, boolean useKMeansPlusPlus) { | ||||
|         ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse) | ||||
|                         .endWhenDistributionVariationRateLessThan(minDistributionVariationRate); | ||||
|         return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Setup a kmeans instance | ||||
|      * @param clusterCount the number of clusters | ||||
|      * @param maxIterationCount the max number of iterations | ||||
|      *                          to run kmeans | ||||
|      * @param distanceFunction the distance function to use for grouping | ||||
|      * @return a configured KMeansClustering instance with inverse distances disabled | ||||
|      */ | ||||
|     public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction, boolean useKMeansPlusPlus) { | ||||
|         return setup(clusterCount, maxIterationCount, distanceFunction, false, useKMeansPlusPlus); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param clusterCount the number of clusters | ||||
|      * @param minDistributionVariationRate the variation rate below which iteration stops | ||||
|      * @param distanceFunction the distance function to use for grouping | ||||
|      * @param allowEmptyClusters whether clusters may end up empty (note: not consulted by this overload) | ||||
|      * @param useKMeansPlusPlus whether to use k-means++ center initialization | ||||
|      * @return a configured KMeansClustering instance | ||||
|      */ | ||||
|     public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction, | ||||
|                     boolean allowEmptyClusters, boolean useKMeansPlusPlus) { | ||||
|         ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false); | ||||
|         clusteringStrategy.endWhenDistributionVariationRateLessThan(minDistributionVariationRate); | ||||
|         return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus); | ||||
|     } | ||||
| 
 | ||||
|     public static KMeansClustering setup(int clusterCount, Distance distanceFunction, | ||||
|                                          boolean allowEmptyClusters, boolean useKMeansPlusPlus) { | ||||
|         ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false); | ||||
|         clusteringStrategy.endWhenDistributionVariationRateLessThan(VARIATION_TOLERANCE); | ||||
|         return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus); | ||||
|     } | ||||
| 
 | ||||
| } | ||||
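| // A minimal usage sketch, assuming the companion Point and ClusterSet classes of | ||||
| // org.deeplearning4j.clustering.cluster and the applyTo(...) entry point inherited | ||||
| // from BaseClusteringAlgorithm: | ||||
| import org.deeplearning4j.clustering.algorithm.Distance; | ||||
| import org.deeplearning4j.clustering.cluster.ClusterSet; | ||||
| import org.deeplearning4j.clustering.cluster.Point; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import java.util.List; | ||||
| class KMeansExample { | ||||
|     public static void main(String[] args) { | ||||
|         KMeansClustering kmeans = KMeansClustering.setup(3, 100, Distance.EUCLIDEAN, true); | ||||
|         List<Point> points = Point.toPoints(Nd4j.rand(50, 4)); // 50 random 4-d points | ||||
|         ClusterSet clusters = kmeans.applyTo(points); | ||||
|         System.out.println(clusters.getClusterCount() + " clusters"); | ||||
|     } | ||||
| } | ||||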
| @ -1,88 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.lsh; | ||||
| 
 | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| 
 | ||||
| public interface LSH { | ||||
| 
 | ||||
|     /** | ||||
|      * Returns the name of the distance measure associated with the LSH family of this implementation. | ||||
|      * Beware, hashing families and their amplification constructs are distance-specific. | ||||
|      */ | ||||
|      String getDistanceMeasure(); | ||||
| 
 | ||||
|     /** | ||||
|      * Returns the length of the hashes compared within one bucket, corresponding to an AND construction | ||||
|      * | ||||
|      * denoting hashLength by h, | ||||
|      * amplifies a (d1, d2, p1, p2) hash family into a | ||||
|      *                   (d1, d2, p1^h, p2^h)-sensitive one (match probability is decreasing with h) | ||||
|      * | ||||
|      * @return the length of the hash in the AND construction used by this index | ||||
|      */ | ||||
|      int getHashLength(); | ||||
| 
 | ||||
|     /** | ||||
|      * | ||||
|      * denoting numTables by n, | ||||
|      * amplifies a (d1, d2, p1, p2) hash family into a | ||||
|      *                   (d1, d2, (1-p1^n), (1-p2^n))-sensitive one (match probability is increasing with n) | ||||
|      * | ||||
|      * @return the # of hash tables in the OR construction used by this index | ||||
|      */ | ||||
|      int getNumTables(); | ||||
| 
 | ||||
|     /** | ||||
|      * @return The dimension of the index vectors and queries | ||||
|      */ | ||||
|      int getInDimension(); | ||||
| 
 | ||||
|     /** | ||||
|      * Populates the index with data vectors. | ||||
|      * @param data the vectors to index | ||||
|      */ | ||||
|      void makeIndex(INDArray data); | ||||
| 
 | ||||
|     /** | ||||
|      * Returns the set of all vectors that could approximately be considered neighbors of the query, | ||||
|      * without selection on the basis of distance or number of neighbors. | ||||
|      * @param query a vector to find neighbors for | ||||
|      * @return its approximate neighbors, unfiltered | ||||
|      */ | ||||
|      INDArray bucket(INDArray query); | ||||
| 
 | ||||
|     /** | ||||
|      * Returns the approximate neighbors within a distance bound. | ||||
|      * @param query a vector to find neighbors for | ||||
|      * @param maxRange the maximum distance between results and the query | ||||
|      * @return approximate neighbors within the distance bounds | ||||
|      */ | ||||
|      INDArray search(INDArray query, double maxRange); | ||||
| 
 | ||||
|     /** | ||||
|      * Returns the approximate neighbors within a k-closest bound | ||||
|      * @param query a vector to find neighbors for | ||||
|      * @param k the maximum number of closest neighbors to return | ||||
|      * @return at most k neighbors of the query, ordered by increasing distance | ||||
|      */ | ||||
|      INDArray search(INDArray query, int k); | ||||
| } | ||||
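| // A small numeric sketch of the AND/OR amplification described in the javadoc above, | ||||
| // for base collision probabilities p1 (near pairs) and p2 (far pairs): | ||||
| class AmplificationExample { | ||||
|     public static void main(String[] args) { | ||||
|         double p1 = 0.9, p2 = 0.5; // base family collision probabilities | ||||
|         int h = 8, n = 10;         // hashLength (AND) and numTables (OR) | ||||
|         double nearAnd = Math.pow(p1, h);                    // p1^h ~ 0.43, selective | ||||
|         double nearOr = 1 - Math.pow(1 - nearAnd, n);        // 1-(1-p1^h)^n ~ 0.996 | ||||
|         double farOr = 1 - Math.pow(1 - Math.pow(p2, h), n); // ~ 0.038, far pairs stay rare | ||||
|         System.out.printf("near: %.3f -> %.3f, far: %.3f%n", nearAnd, nearOr, farOr); | ||||
|     } | ||||
| } | ||||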
| @ -1,227 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.lsh; | ||||
| 
 | ||||
| import lombok.Getter; | ||||
| import lombok.val; | ||||
| import org.nd4j.common.base.Preconditions; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo; | ||||
| import org.nd4j.linalg.api.ops.impl.transforms.same.Sign; | ||||
| import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution; | ||||
| import org.nd4j.linalg.api.rng.Random; | ||||
| import org.nd4j.linalg.exception.ND4JIllegalStateException; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.linalg.indexing.BooleanIndexing; | ||||
| import org.nd4j.linalg.indexing.conditions.Conditions; | ||||
| import org.nd4j.linalg.ops.transforms.Transforms; | ||||
| 
 | ||||
| import java.util.Arrays; | ||||
| 
 | ||||
| 
 | ||||
| public class RandomProjectionLSH implements LSH { | ||||
| 
 | ||||
|     @Override | ||||
|     public String getDistanceMeasure(){ | ||||
|         return "cosinedistance"; | ||||
|     } | ||||
| 
 | ||||
|     @Getter private int hashLength; | ||||
| 
 | ||||
|     @Getter private int numTables; | ||||
| 
 | ||||
|     @Getter private int inDimension; | ||||
| 
 | ||||
| 
 | ||||
|     @Getter private double radius; | ||||
| 
 | ||||
|     INDArray randomProjection; | ||||
| 
 | ||||
|     INDArray index; | ||||
| 
 | ||||
|     INDArray indexData; | ||||
| 
 | ||||
| 
 | ||||
|     private INDArray gaussianRandomMatrix(int[] shape, Random rng){ | ||||
|         INDArray res = Nd4j.create(shape); | ||||
| 
 | ||||
|         GaussianDistribution op1 = new GaussianDistribution(res, 0.0, 1.0 / Math.sqrt(shape[0])); | ||||
| 
 | ||||
|         Nd4j.getExecutioner().exec(op1, rng); | ||||
|         return res; | ||||
|     } | ||||
| 
 | ||||
|     public RandomProjectionLSH(int hashLength, int numTables, int inDimension, double radius){ | ||||
|         this(hashLength, numTables, inDimension, radius, Nd4j.getRandom()); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Creates a locality-sensitive hashing index for the cosine distance, | ||||
|      * a (d1, d2, (180 − d1)/180, (180 − d2)/180)-sensitive hash family before amplification | ||||
|      * | ||||
|      * @param hashLength the length of the compared hash in an AND construction | ||||
|      * @param numTables the entropy-based equivalent of a number of hash tables in an OR construction, implemented | ||||
|      *                  here with the multiple probes of Panigrahy (op. cit.) | ||||
|      * @param inDimension the dimensionality of the points being indexed | ||||
|      * @param radius the radius around the query within which probes are drawn, standing in for multiple physical hash tables in an OR construction | ||||
|      * @param rng a Random object to draw samples from | ||||
|      */ | ||||
|     public RandomProjectionLSH(int hashLength, int numTables, int inDimension, double radius, Random rng){ | ||||
|         this.hashLength = hashLength; | ||||
|         this.numTables = numTables; | ||||
|         this.inDimension = inDimension; | ||||
|         this.radius = radius; | ||||
|         randomProjection = gaussianRandomMatrix(new int[]{inDimension, hashLength}, rng); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * This picks uniformly distributed random points on the surface of a sphere using the method of: | ||||
|      * | ||||
|      * An efficient method for generating uniformly distributed points on the surface of an n-dimensional sphere | ||||
|      * JS Hicks, RF Wheeling - Communications of the ACM, 1959 | ||||
|      * @param data a query to generate multiple probes for | ||||
|      * @return numTables probe points centered on the query, one per simulated hash table | ||||
|      */ | ||||
|     public INDArray entropy(INDArray data){ | ||||
| 
 | ||||
|         INDArray data2 = | ||||
|                     Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.create(numTables, inDimension), radius)); | ||||
| 
 | ||||
|         INDArray norms = Nd4j.norm2(data2.dup(), -1); | ||||
| 
 | ||||
|         Preconditions.checkState(norms.rank() == 1 && norms.size(0) == numTables, "Expected norm2 to have shape [%s], is %ndShape", norms.size(0), norms); | ||||
| 
 | ||||
|         data2.diviColumnVector(norms); | ||||
|         data2.addiRowVector(data); | ||||
|         return data2; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Returns hash values for a particular query | ||||
|      * @param data a query vector | ||||
|      * @return its hashed value | ||||
|      */ | ||||
|     public INDArray hash(INDArray data) { | ||||
|         if (data.shape()[1] != inDimension){ | ||||
|             throw new ND4JIllegalStateException( | ||||
|                     String.format("Invalid shape: Requested INDArray shape %s, this table expects dimension %d", | ||||
|                             Arrays.toString(data.shape()), inDimension)); | ||||
|         } | ||||
|         INDArray projected = data.mmul(randomProjection); | ||||
|         INDArray res = Nd4j.getExecutioner().exec(new Sign(projected)); | ||||
|         return res; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Populates the index. Beware: this is not incremental; any further call replaces the index instead of adding to it. | ||||
|      * @param data the vectors to index | ||||
|      */ | ||||
|     @Override | ||||
|     public void makeIndex(INDArray data) { | ||||
|         index = hash(data); | ||||
|         indexData = data; | ||||
|     } | ||||
| 
 | ||||
|     // data elements in the same bucket as the query, without entropy | ||||
|     INDArray rawBucketOf(INDArray query){ | ||||
|         INDArray pattern = hash(query); | ||||
| 
 | ||||
|         INDArray res = Nd4j.zeros(DataType.BOOL, index.shape()); | ||||
|         Nd4j.getExecutioner().exec(new BroadcastEqualTo(index, pattern, res, -1)); | ||||
|         return res.castTo(Nd4j.defaultFloatingPointType()).min(-1); | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray bucket(INDArray query) { | ||||
|         INDArray queryRes = rawBucketOf(query); | ||||
| 
 | ||||
|         if(numTables > 1) { | ||||
|             INDArray entropyQueries = entropy(query); | ||||
| 
 | ||||
|             // loop, addi + conditionalreplace -> poor man's OR function | ||||
|             for (int i = 0; i < numTables; i++) { | ||||
|                 INDArray row = entropyQueries.getRow(i, true); | ||||
|                 queryRes.addi(rawBucketOf(row)); | ||||
|             } | ||||
|             BooleanIndexing.replaceWhere(queryRes, 1.0, Conditions.greaterThan(0.0)); | ||||
|         } | ||||
| 
 | ||||
|         return queryRes; | ||||
|     } | ||||
| 
 | ||||
|     // data elements in the same entropy bucket as the query | ||||
|     INDArray bucketData(INDArray query){ | ||||
|         INDArray mask = bucket(query); | ||||
|         int nRes = mask.sum(0).getInt(0); | ||||
|         INDArray res = Nd4j.create(new int[] {nRes, inDimension}); | ||||
|         int j = 0; | ||||
|         for (int i = 0; i < nRes; i++){ | ||||
|             while (mask.getInt(j) == 0 && j < mask.length() - 1) { | ||||
|                 j += 1; | ||||
|             } | ||||
|             if (mask.getInt(j) == 1) res.putRow(i, indexData.getRow(j)); | ||||
|             j += 1; | ||||
|         } | ||||
|         return res; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray search(INDArray query, double maxRange) { | ||||
|         if (maxRange < 0) | ||||
|             throw new IllegalArgumentException("ANN search should have a positive maximum search radius"); | ||||
| 
 | ||||
|         INDArray bucketData = bucketData(query); | ||||
|         INDArray distances = Transforms.allCosineDistances(bucketData, query, -1); | ||||
|         INDArray[] idxs = Nd4j.sortWithIndices(distances, -1, true); | ||||
| 
 | ||||
|         INDArray shuffleIndexes = idxs[0]; | ||||
|         INDArray sortedDistances = idxs[1]; | ||||
|         int accepted = 0; | ||||
|         // distances are fractional; reading them with getInt would truncate toward zero | ||||
|         while (accepted < sortedDistances.length() && sortedDistances.getDouble(accepted) <= maxRange) accepted += 1; | ||||
| 
 | ||||
|         INDArray res = Nd4j.create(new int[] {accepted, inDimension}); | ||||
|         for(int i = 0; i < accepted; i++){ | ||||
|             res.putRow(i, bucketData.getRow(shuffleIndexes.getInt(i))); | ||||
|         } | ||||
|         return res; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public INDArray search(INDArray query, int k) { | ||||
|         if (k < 1) | ||||
|             throw new IllegalArgumentException("An ANN search for k neighbors should at least seek one neighbor"); | ||||
| 
 | ||||
|         INDArray bucketData = bucketData(query); | ||||
|         INDArray distances = Transforms.allCosineDistances(bucketData, query, -1); | ||||
|         INDArray[] idxs = Nd4j.sortWithIndices(distances, -1, true); | ||||
| 
 | ||||
|         INDArray shuffleIndexes = idxs[0]; | ||||
|         INDArray sortedDistances = idxs[1]; | ||||
|         val accepted = Math.min(k, sortedDistances.shape()[1]); | ||||
| 
 | ||||
|         INDArray res = Nd4j.create(accepted, inDimension); | ||||
|         for(int i = 0; i < accepted; i++){ | ||||
|             res.putRow(i, bucketData.getRow(shuffleIndexes.getInt(i))); | ||||
|         } | ||||
|         return res; | ||||
|     } | ||||
| } | ||||
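| // A minimal usage sketch: index a matrix of row vectors, then fetch the five | ||||
| // approximate nearest neighbors of the first row by cosine distance. | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| class RandomProjectionLSHExample { | ||||
|     public static void main(String[] args) { | ||||
|         RandomProjectionLSH lsh = new RandomProjectionLSH(16, 4, 100, 0.1); | ||||
|         INDArray vectors = Nd4j.rand(1000, 100); // 1000 vectors of dimension 100 | ||||
|         lsh.makeIndex(vectors); | ||||
|         INDArray neighbors = lsh.search(vectors.getRow(0, true), 5); | ||||
|         System.out.println(neighbors.rows() + " neighbors returned"); | ||||
|     } | ||||
| } | ||||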
| @ -1,38 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.optimisation; | ||||
| 
 | ||||
| import lombok.AccessLevel; | ||||
| import lombok.AllArgsConstructor; | ||||
| import lombok.Data; | ||||
| import lombok.NoArgsConstructor; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| 
 | ||||
| @Data | ||||
| @NoArgsConstructor(access = AccessLevel.PROTECTED) | ||||
| @AllArgsConstructor | ||||
| public class ClusteringOptimization implements Serializable { | ||||
| 
 | ||||
|     private ClusteringOptimizationType type; | ||||
|     private double value; | ||||
| 
 | ||||
| } | ||||
| @ -1,28 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.optimisation; | ||||
| 
 | ||||
| /** | ||||
|  * Objective functions a clustering run can be asked to optimize. | ||||
|  */ | ||||
| public enum ClusteringOptimizationType { | ||||
|     MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE, MINIMIZE_MAXIMUM_POINT_TO_CENTER_DISTANCE, MINIMIZE_AVERAGE_POINT_TO_POINT_DISTANCE, MINIMIZE_MAXIMUM_POINT_TO_POINT_DISTANCE, MINIMIZE_PER_CLUSTER_POINT_COUNT | ||||
| } | ||||
| @ -1,115 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.quadtree; | ||||
| 
 | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| 
 | ||||
| public class Cell implements Serializable { | ||||
|     private double x, y, hw, hh; | ||||
| 
 | ||||
|     public Cell(double x, double y, double hw, double hh) { | ||||
|         this.x = x; | ||||
|         this.y = y; | ||||
|         this.hw = hw; | ||||
|         this.hh = hh; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Whether the given point is contained | ||||
|      * within this cell | ||||
|      * @param point the point to check | ||||
|      * @return true if the point is contained, false otherwise | ||||
|      */ | ||||
|     public boolean containsPoint(INDArray point) { | ||||
|         double first = point.getDouble(0), second = point.getDouble(1); | ||||
|         return x - hw <= first && x + hw >= first && y - hh <= second && y + hh >= second; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public boolean equals(Object o) { | ||||
|         if (this == o) | ||||
|             return true; | ||||
|         if (!(o instanceof Cell)) | ||||
|             return false; | ||||
| 
 | ||||
|         Cell cell = (Cell) o; | ||||
| 
 | ||||
|         if (Double.compare(cell.hh, hh) != 0) | ||||
|             return false; | ||||
|         if (Double.compare(cell.hw, hw) != 0) | ||||
|             return false; | ||||
|         if (Double.compare(cell.x, x) != 0) | ||||
|             return false; | ||||
|         return Double.compare(cell.y, y) == 0; | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public int hashCode() { | ||||
|         int result; | ||||
|         long temp; | ||||
|         temp = Double.doubleToLongBits(x); | ||||
|         result = (int) (temp ^ (temp >>> 32)); | ||||
|         temp = Double.doubleToLongBits(y); | ||||
|         result = 31 * result + (int) (temp ^ (temp >>> 32)); | ||||
|         temp = Double.doubleToLongBits(hw); | ||||
|         result = 31 * result + (int) (temp ^ (temp >>> 32)); | ||||
|         temp = Double.doubleToLongBits(hh); | ||||
|         result = 31 * result + (int) (temp ^ (temp >>> 32)); | ||||
|         return result; | ||||
|     } | ||||
| 
 | ||||
|     public double getX() { | ||||
|         return x; | ||||
|     } | ||||
| 
 | ||||
|     public void setX(double x) { | ||||
|         this.x = x; | ||||
|     } | ||||
| 
 | ||||
|     public double getY() { | ||||
|         return y; | ||||
|     } | ||||
| 
 | ||||
|     public void setY(double y) { | ||||
|         this.y = y; | ||||
|     } | ||||
| 
 | ||||
|     public double getHw() { | ||||
|         return hw; | ||||
|     } | ||||
| 
 | ||||
|     public void setHw(double hw) { | ||||
|         this.hw = hw; | ||||
|     } | ||||
| 
 | ||||
|     public double getHh() { | ||||
|         return hh; | ||||
|     } | ||||
| 
 | ||||
|     public void setHh(double hh) { | ||||
|         this.hh = hh; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
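| // A minimal sketch: (x, y) is the cell center and hw/hh are half-extents, so this | ||||
| // cell spans [-1, 1] on both axes. | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| class CellExample { | ||||
|     public static void main(String[] args) { | ||||
|         Cell cell = new Cell(0.0, 0.0, 1.0, 1.0); | ||||
|         boolean inside = cell.containsPoint(Nd4j.createFromArray(0.5, -0.5)); // true | ||||
|         boolean outside = cell.containsPoint(Nd4j.createFromArray(1.5, 0.0)); // false | ||||
|         System.out.println(inside + " " + outside); | ||||
|     } | ||||
| } | ||||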
| @ -1,383 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.quadtree; | ||||
| 
 | ||||
| import org.nd4j.shade.guava.util.concurrent.AtomicDouble; | ||||
| import org.apache.commons.math3.util.FastMath; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| 
 | ||||
| import static java.lang.Math.max; | ||||
| 
 | ||||
| public class QuadTree implements Serializable { | ||||
|     private QuadTree parent, northWest, northEast, southWest, southEast; | ||||
|     private boolean isLeaf = true; | ||||
|     private int size, cumSize; | ||||
|     private Cell boundary; | ||||
|     static final int QT_NO_DIMS = 2; | ||||
|     static final int QT_NODE_CAPACITY = 1; | ||||
|     private INDArray buf = Nd4j.create(QT_NO_DIMS); | ||||
|     private INDArray data, centerOfMass = Nd4j.create(QT_NO_DIMS); | ||||
|     private int[] index = new int[QT_NODE_CAPACITY]; | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Builds the tree over a matrix of 2-D points, one row per point. | ||||
|      * @param data the matrix of points to index | ||||
|      */ | ||||
|     public QuadTree(INDArray data) { | ||||
|         INDArray meanY = data.mean(0); | ||||
|         INDArray minY = data.min(0); | ||||
|         INDArray maxY = data.max(0); | ||||
|         init(data, meanY.getDouble(0), meanY.getDouble(1), | ||||
|                         max(maxY.getDouble(0) - meanY.getDouble(0), meanY.getDouble(0) - minY.getDouble(0)) | ||||
|                                         + Nd4j.EPS_THRESHOLD, | ||||
|                         max(maxY.getDouble(1) - meanY.getDouble(1), meanY.getDouble(1) - minY.getDouble(1)) | ||||
|                                         + Nd4j.EPS_THRESHOLD); | ||||
|         fill(); | ||||
|     } | ||||
| 
 | ||||
|     public QuadTree(QuadTree parent, INDArray data, Cell boundary) { | ||||
|         this.parent = parent; | ||||
|         this.boundary = boundary; | ||||
|         this.data = data; | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     public QuadTree(Cell boundary) { | ||||
|         this.boundary = boundary; | ||||
|     } | ||||
| 
 | ||||
|     private void init(INDArray data, double x, double y, double hw, double hh) { | ||||
|         boundary = new Cell(x, y, hw, hh); | ||||
|         this.data = data; | ||||
|     } | ||||
| 
 | ||||
|     private void fill() { | ||||
|         for (int i = 0; i < data.rows(); i++) | ||||
|             insert(i); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Returns the child quadrant that should hold the given point | ||||
|      * | ||||
|      * @param coordinates the (x, y) coordinates of the point | ||||
|      * @return the child quad tree for the matching quadrant | ||||
|      */ | ||||
|     protected QuadTree findIndex(INDArray coordinates) { | ||||
| 
 | ||||
|         // Compute the sector for the coordinates | ||||
|         boolean left = (coordinates.getDouble(0) <= (boundary.getX() + boundary.getHw() / 2)); | ||||
|         boolean top = (coordinates.getDouble(1) <= (boundary.getY() + boundary.getHh() / 2)); | ||||
| 
 | ||||
|         // top left | ||||
|         QuadTree index = getNorthWest(); | ||||
|         if (left) { | ||||
|             // left side | ||||
|             if (!top) { | ||||
|                 // bottom left | ||||
|                 index = getSouthWest(); | ||||
|             } | ||||
|         } else { | ||||
|             // right side | ||||
|             if (top) { | ||||
|                 // top right | ||||
|                 index = getNorthEast(); | ||||
|             } else { | ||||
|                 // bottom right | ||||
|                 index = getSouthEast(); | ||||
| 
 | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         return index; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Insert an index of the data into the tree | ||||
|      * @param newIndex the index to insert into the tree | ||||
|      * @return whether the index was inserted or not | ||||
|      */ | ||||
|     public boolean insert(int newIndex) { | ||||
|         // Ignore objects which do not belong in this quad tree | ||||
|         INDArray point = data.slice(newIndex); | ||||
|         if (!boundary.containsPoint(point)) | ||||
|             return false; | ||||
| 
 | ||||
|         cumSize++; | ||||
|         double mult1 = (double) (cumSize - 1) / (double) cumSize; | ||||
|         double mult2 = 1.0 / (double) cumSize; | ||||
| 
 | ||||
|         centerOfMass.muli(mult1); | ||||
|         centerOfMass.addi(point.mul(mult2)); | ||||
| 
 | ||||
|         // If there is space in this quad tree and it is a leaf, add the object here | ||||
|         if (isLeaf() && size < QT_NODE_CAPACITY) { | ||||
|             index[size] = newIndex; | ||||
|             size++; | ||||
|             return true; | ||||
|         } | ||||
| 
 | ||||
|         //duplicate point: already present in this leaf, treat as inserted | ||||
|         if (size > 0) { | ||||
|             for (int i = 0; i < size; i++) { | ||||
|                 INDArray compPoint = data.slice(index[i]); | ||||
|                 if (point.getDouble(0) == compPoint.getDouble(0) && point.getDouble(1) == compPoint.getDouble(1)) | ||||
|                     return true; | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|         // If this Node has already been subdivided just add the elements to the | ||||
|         // appropriate cell | ||||
|         if (!isLeaf()) { | ||||
|             QuadTree index = findIndex(point); | ||||
|             index.insert(newIndex); | ||||
|             return true; | ||||
|         } | ||||
| 
 | ||||
|         if (isLeaf()) | ||||
|             subDivide(); | ||||
| 
 | ||||
|         boolean ret = insertIntoOneOf(newIndex); | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|         return ret; | ||||
|     } | ||||
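| The mult1/mult2 update above is the standard incremental mean: after n insertions, centerOfMass equals the average of the inserted points. A hedged scalar sketch of the same identity: | ||||
|     double com = 0.0;                       // running center of mass, 1-D for illustration | ||||
|     double[] xs = {1.0, 3.0, 8.0}; | ||||
|     for (int n = 1; n <= xs.length; n++) { | ||||
|         com = com * (n - 1) / (double) n + xs[n - 1] / (double) n; | ||||
|     } | ||||
|     // com == (1.0 + 3.0 + 8.0) / 3 == 4.0 | ||||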
| 
 | ||||
|     private boolean insertIntoOneOf(int index) { | ||||
|         boolean success = false; | ||||
|         success = northWest.insert(index); | ||||
|         if (!success) | ||||
|             success = northEast.insert(index); | ||||
|         if (!success) | ||||
|             success = southWest.insert(index); | ||||
|         if (!success) | ||||
|             success = southEast.insert(index); | ||||
|         return success; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Returns whether the tree is consistent or not | ||||
|      * @return whether the tree is consistent or not | ||||
|      */ | ||||
|     public boolean isCorrect() { | ||||
| 
 | ||||
|         for (int n = 0; n < size; n++) { | ||||
|             INDArray point = data.slice(index[n]); | ||||
|             if (!boundary.containsPoint(point)) | ||||
|                 return false; | ||||
|         } | ||||
| 
 | ||||
|         return isLeaf() || (northWest.isCorrect() && northEast.isCorrect() && southWest.isCorrect() | ||||
|                         && southEast.isCorrect()); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      *  Create four children | ||||
|      *  which fully divide this cell | ||||
|      *  into four quads of equal area | ||||
|      */ | ||||
|     public void subDivide() { | ||||
|         northWest = new QuadTree(this, data, new Cell(boundary.getX() - .5 * boundary.getHw(), | ||||
|                         boundary.getY() - .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh())); | ||||
|         northEast = new QuadTree(this, data, new Cell(boundary.getX() + .5 * boundary.getHw(), | ||||
|                         boundary.getY() - .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh())); | ||||
|         southWest = new QuadTree(this, data, new Cell(boundary.getX() - .5 * boundary.getHw(), | ||||
|                         boundary.getY() + .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh())); | ||||
|         southEast = new QuadTree(this, data, new Cell(boundary.getX() + .5 * boundary.getHw(), | ||||
|                         boundary.getY() + .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh())); | ||||
| 
 | ||||
|         // the node now has children and must no longer report itself as a leaf, | ||||
|         // otherwise insert() and computeNonEdgeForces() would keep treating it as one | ||||
|         isLeaf = false; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Compute the non-edge (repulsive) forces for one point using the Barnes-Hut approximation | ||||
|      * @param pointIndex the index of the point to compute forces for | ||||
|      * @param theta the accuracy/speed trade-off: a cell is summarized by its center of mass | ||||
|      *              once its size divided by its distance falls below theta | ||||
|      * @param negativeForce the buffer the repulsive force is accumulated into | ||||
|      * @param sumQ accumulator for the normalization term (the sum of Q values) | ||||
|      */ | ||||
|     public void computeNonEdgeForces(int pointIndex, double theta, INDArray negativeForce, AtomicDouble sumQ) { | ||||
|         // Make sure that we spend no time on empty nodes or self-interactions | ||||
|         if (cumSize == 0 || (isLeaf() && size == 1 && index[0] == pointIndex)) | ||||
|             return; | ||||
| 
 | ||||
| 
 | ||||
|         // Compute distance between point and center-of-mass | ||||
|         buf.assign(data.slice(pointIndex)).subi(centerOfMass); | ||||
| 
 | ||||
|         double D = Nd4j.getBlasWrapper().dot(buf, buf); | ||||
| 
 | ||||
|         // Check whether we can use this node as a "summary" | ||||
|         if (isLeaf || FastMath.max(boundary.getHh(), boundary.getHw()) / FastMath.sqrt(D) < theta) { | ||||
| 
 | ||||
|             // Compute and add t-SNE force between point and current node | ||||
|             double Q = 1.0 / (1.0 + D); | ||||
|             double mult = cumSize * Q; | ||||
|             sumQ.addAndGet(mult); | ||||
|             mult *= Q; | ||||
|             negativeForce.addi(buf.mul(mult)); | ||||
| 
 | ||||
|         } else { | ||||
| 
 | ||||
|             // Recursively apply Barnes-Hut to children | ||||
|             northWest.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ); | ||||
|             northEast.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ); | ||||
|             southWest.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ); | ||||
|             southEast.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ); | ||||
|         } | ||||
|     } | ||||
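| A hedged sketch of how a Barnes-Hut t-SNE gradient step might drive this method; the data, theta = 0.5, and buffer shapes are illustrative assumptions: | ||||
|     INDArray data = Nd4j.rand(50, 2); | ||||
|     QuadTree tree = new QuadTree(data); | ||||
|     INDArray negForces = Nd4j.zeros(50, 2); | ||||
|     AtomicDouble sumQ = new AtomicDouble(0.0); | ||||
|     for (int i = 0; i < data.rows(); i++) { | ||||
|         tree.computeNonEdgeForces(i, 0.5, negForces.slice(i), sumQ); | ||||
|     } | ||||
|     // each row of negForces, divided by sumQ, approximates that point's repulsive gradient term | ||||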
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Compute the edge (attractive) forces for a sparse similarity matrix in CSR form | ||||
|      * @param rowP row offsets; entries rowP[n]..rowP[n + 1] delimit the edges of point n | ||||
|      * @param colP the column (neighbor) index of each edge | ||||
|      * @param valP the input similarity value of each edge | ||||
|      * @param N the number of points | ||||
|      * @param posF the buffer the positive forces are accumulated into, one row per point | ||||
|      */ | ||||
|     public void computeEdgeForces(INDArray rowP, INDArray colP, INDArray valP, int N, INDArray posF) { | ||||
|         if (!rowP.isVector()) | ||||
|             throw new IllegalArgumentException("RowP must be a vector"); | ||||
| 
 | ||||
|         // Loop over all edges in the graph | ||||
|         double D; | ||||
|         for (int n = 0; n < N; n++) { | ||||
|             for (int i = rowP.getInt(n); i < rowP.getInt(n + 1); i++) { | ||||
| 
 | ||||
|                 // Compute pairwise distance and Q-value | ||||
|                 buf.assign(data.slice(n)).subi(data.slice(colP.getInt(i))); | ||||
| 
 | ||||
|                 D = Nd4j.getBlasWrapper().dot(buf, buf); | ||||
|                 D = valP.getDouble(i) / D; | ||||
| 
 | ||||
|                 // Sum positive force | ||||
|                 posF.slice(n).addi(buf.mul(D)); | ||||
| 
 | ||||
|             } | ||||
|         } | ||||
|     } | ||||
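| The three vectors form a CSR-style sparse matrix. A hedged toy encoding, not from the original file, for three points where point 0 links to 1 and 2, and points 1 and 2 each link back to 0: | ||||
|     INDArray rowP = Nd4j.create(new double[] {0, 2, 3, 4});         // offsets, length N + 1 | ||||
|     INDArray colP = Nd4j.create(new double[] {1, 2, 0, 0});         // neighbor index per edge | ||||
|     INDArray valP = Nd4j.create(new double[] {0.4, 0.6, 1.0, 1.0}); // similarity per edge | ||||
|     INDArray posF = Nd4j.zeros(3, 2);                               // attractive-force output | ||||
|     new QuadTree(Nd4j.rand(3, 2)).computeEdgeForces(rowP, colP, valP, 3, posF); | ||||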
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * The depth of the node | ||||
|      * @return the depth of the node | ||||
|      */ | ||||
|     public int depth() { | ||||
|         if (isLeaf()) | ||||
|             return 1; | ||||
|         return 1 + max(max(northWest.depth(), northEast.depth()), max(southWest.depth(), southEast.depth())); | ||||
|     } | ||||
| 
 | ||||
|     public INDArray getCenterOfMass() { | ||||
|         return centerOfMass; | ||||
|     } | ||||
| 
 | ||||
|     public void setCenterOfMass(INDArray centerOfMass) { | ||||
|         this.centerOfMass = centerOfMass; | ||||
|     } | ||||
| 
 | ||||
|     public QuadTree getParent() { | ||||
|         return parent; | ||||
|     } | ||||
| 
 | ||||
|     public void setParent(QuadTree parent) { | ||||
|         this.parent = parent; | ||||
|     } | ||||
| 
 | ||||
|     public QuadTree getNorthWest() { | ||||
|         return northWest; | ||||
|     } | ||||
| 
 | ||||
|     public void setNorthWest(QuadTree northWest) { | ||||
|         this.northWest = northWest; | ||||
|     } | ||||
| 
 | ||||
|     public QuadTree getNorthEast() { | ||||
|         return northEast; | ||||
|     } | ||||
| 
 | ||||
|     public void setNorthEast(QuadTree northEast) { | ||||
|         this.northEast = northEast; | ||||
|     } | ||||
| 
 | ||||
|     public QuadTree getSouthWest() { | ||||
|         return southWest; | ||||
|     } | ||||
| 
 | ||||
|     public void setSouthWest(QuadTree southWest) { | ||||
|         this.southWest = southWest; | ||||
|     } | ||||
| 
 | ||||
|     public QuadTree getSouthEast() { | ||||
|         return southEast; | ||||
|     } | ||||
| 
 | ||||
|     public void setSouthEast(QuadTree southEast) { | ||||
|         this.southEast = southEast; | ||||
|     } | ||||
| 
 | ||||
|     public boolean isLeaf() { | ||||
|         return isLeaf; | ||||
|     } | ||||
| 
 | ||||
|     public void setLeaf(boolean isLeaf) { | ||||
|         this.isLeaf = isLeaf; | ||||
|     } | ||||
| 
 | ||||
|     public int getSize() { | ||||
|         return size; | ||||
|     } | ||||
| 
 | ||||
|     public void setSize(int size) { | ||||
|         this.size = size; | ||||
|     } | ||||
| 
 | ||||
|     public int getCumSize() { | ||||
|         return cumSize; | ||||
|     } | ||||
| 
 | ||||
|     public void setCumSize(int cumSize) { | ||||
|         this.cumSize = cumSize; | ||||
|     } | ||||
| 
 | ||||
|     public Cell getBoundary() { | ||||
|         return boundary; | ||||
|     } | ||||
| 
 | ||||
|     public void setBoundary(Cell boundary) { | ||||
|         this.boundary = boundary; | ||||
|     } | ||||
| } | ||||
| @ -1,104 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.randomprojection; | ||||
| 
 | ||||
| import lombok.Data; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.common.primitives.Pair; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| /** | ||||
|  * A forest of random projection trees for approximate nearest-neighbor search. | ||||
|  */ | ||||
| @Data | ||||
| public class RPForest { | ||||
| 
 | ||||
|     private int numTrees; | ||||
|     private List<RPTree> trees; | ||||
|     private INDArray data; | ||||
|     private int maxSize = 1000; | ||||
|     private String similarityFunction; | ||||
| 
 | ||||
|     /** | ||||
|      * Create the rp forest with the specified number of trees | ||||
|      * @param numTrees the number of trees in the forest | ||||
|      * @param maxSize the max size of each tree | ||||
|      * @param similarityFunction the distance function to use | ||||
|      */ | ||||
|     public RPForest(int numTrees,int maxSize,String similarityFunction) { | ||||
|         this.numTrees = numTrees; | ||||
|         this.maxSize = maxSize; | ||||
|         this.similarityFunction = similarityFunction; | ||||
|         trees = new ArrayList<>(numTrees); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Build the trees from the given dataset | ||||
|      * @param x the input dataset (should be a 2d matrix) | ||||
|      */ | ||||
|     public void fit(INDArray x) { | ||||
|         this.data = x; | ||||
|         for(int i = 0; i < numTrees; i++) { | ||||
|             RPTree tree = new RPTree(data.columns(),maxSize,similarityFunction); | ||||
|             tree.buildTree(x); | ||||
|             trees.add(tree); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get all candidates relative to a specific datapoint. | ||||
|      * @param input the query vector | ||||
|      * @return the candidate indices aggregated across all trees | ||||
|      */ | ||||
|     public INDArray getAllCandidates(INDArray input) { | ||||
|         return RPUtils.getAllCandidates(input,trees,similarityFunction); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Query for up to n | ||||
|      * nearest neighbors | ||||
|      * @param toQuery the query item | ||||
|      * @param n the number of nearest neighbors for the given data point | ||||
|      * @return the indices for the nearest neighbors | ||||
|      */ | ||||
|     public INDArray queryAll(INDArray toQuery,int n) { | ||||
|         return RPUtils.queryAll(toQuery,data,trees,n,similarityFunction); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Query all trees, returning candidates | ||||
|      * with their distances sorted ascending | ||||
|      * @param query the query vector | ||||
|      * @param numResults the number of results to return | ||||
|      * @return a list of samples | ||||
|      */ | ||||
|     public List<Pair<Double, Integer>> queryWithDistances(INDArray query, int numResults) { | ||||
|         return RPUtils.queryAllWithDistances(query,this.data, trees,numResults,similarityFunction); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
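| An end-to-end hedged sketch of the class above, reusing the imports already present in the file; the sizes and the "euclidean" choice are illustrative assumptions: | ||||
|     INDArray data = Nd4j.rand(1000, 32); | ||||
|     RPForest forest = new RPForest(10, 100, "euclidean"); // 10 trees, at most 100 indices per leaf | ||||
|     forest.fit(data); | ||||
|     INDArray nn = forest.queryAll(data.slice(0), 5);      // indices of ~5 approximate nearest rows | ||||
|     List<Pair<Double, Integer>> nnDist = forest.queryWithDistances(data.slice(0), 5); | ||||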
| @ -1,57 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.randomprojection; | ||||
| 
 | ||||
| import lombok.Data; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| @Data | ||||
| public class RPHyperPlanes { | ||||
|     private int dim; | ||||
|     private INDArray wholeHyperPlane; | ||||
| 
 | ||||
|     public RPHyperPlanes(int dim) { | ||||
|         this.dim = dim; | ||||
|     } | ||||
| 
 | ||||
|     public INDArray getHyperPlaneAt(int depth) { | ||||
|         if(wholeHyperPlane.isVector()) | ||||
|             return wholeHyperPlane; | ||||
|         return wholeHyperPlane.slice(depth); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Add a new random element to the hyper plane. | ||||
|      */ | ||||
|     public void addRandomHyperPlane() { | ||||
|         INDArray newPlane = Nd4j.randn(new int[] {1,dim}); | ||||
|         newPlane.divi(newPlane.normmaxNumber()); | ||||
|         if(wholeHyperPlane == null) | ||||
|             wholeHyperPlane = newPlane; | ||||
|         else { | ||||
|             wholeHyperPlane = Nd4j.concat(0,wholeHyperPlane,newPlane); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
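| Each call appends one Gaussian row normalized by its max norm, so after d calls wholeHyperPlane is a d x dim matrix and getHyperPlaneAt(depth) selects the plane for one tree level. A brief hedged sketch: | ||||
|     RPHyperPlanes planes = new RPHyperPlanes(16); | ||||
|     planes.addRandomHyperPlane();            // wholeHyperPlane is now 1 x 16 (a vector) | ||||
|     planes.addRandomHyperPlane();            // now 2 x 16, one row per tree depth | ||||
|     INDArray planeForDepth1 = planes.getHyperPlaneAt(1); | ||||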
| @ -1,48 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.randomprojection; | ||||
| 
 | ||||
| 
 | ||||
| import lombok.Data; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| import java.util.concurrent.Future; | ||||
| 
 | ||||
| @Data | ||||
| public class RPNode { | ||||
|     private int depth; | ||||
|     private RPNode left,right; | ||||
|     private Future<RPNode> leftFuture,rightFuture; | ||||
|     private List<Integer> indices; | ||||
|     private double median; | ||||
|     private RPTree tree; | ||||
| 
 | ||||
| 
 | ||||
|     public RPNode(RPTree tree,int depth) { | ||||
|         this.depth = depth; | ||||
|         this.tree = tree; | ||||
|         indices = new ArrayList<>(); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,130 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.randomprojection; | ||||
| 
 | ||||
| import lombok.Builder; | ||||
| import lombok.Data; | ||||
| import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; | ||||
| import org.nd4j.linalg.api.memory.enums.*; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.common.primitives.Pair; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| import java.util.List; | ||||
| import java.util.concurrent.ExecutorService; | ||||
| 
 | ||||
| @Data | ||||
| public class RPTree { | ||||
|     private RPNode root; | ||||
|     private RPHyperPlanes rpHyperPlanes; | ||||
|     private int dim; | ||||
|     //also known as leaf size | ||||
|     private int maxSize; | ||||
|     private INDArray X; | ||||
|     private String similarityFunction = "euclidean"; | ||||
|     private WorkspaceConfiguration workspaceConfiguration; | ||||
|     private ExecutorService searchExecutor; | ||||
|     private int searchWorkers; | ||||
| 
 | ||||
|     /** | ||||
|      * | ||||
|      * @param dim the dimension of the vectors | ||||
|      * @param maxSize the max size of the leaves | ||||
|      * @param similarityFunction the similarity function to use (e.g. "euclidean") | ||||
|      */ | ||||
|     @Builder | ||||
|     public RPTree(int dim, int maxSize,String similarityFunction) { | ||||
|         this.dim = dim; | ||||
|         this.maxSize = maxSize; | ||||
|         rpHyperPlanes = new RPHyperPlanes(dim); | ||||
|         root = new RPNode(this,0); | ||||
|         this.similarityFunction = similarityFunction; | ||||
|         workspaceConfiguration = WorkspaceConfiguration.builder().cyclesBeforeInitialization(1) | ||||
|                 .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.FIRST_LOOP) | ||||
|                 .policyMirroring(MirroringPolicy.FULL).policyReset(ResetPolicy.BLOCK_LEFT) | ||||
|                 .policySpill(SpillPolicy.REALLOCATE).build(); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * | ||||
|      * @param dim the dimension of the vectors | ||||
|      * @param maxSize the max size of the leaves | ||||
|      * | ||||
|      */ | ||||
|     public RPTree(int dim, int maxSize) { | ||||
|        this(dim,maxSize,"euclidean"); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      *  Build the tree with the given input data | ||||
|      * @param x the input dataset, one row per vector | ||||
|      */ | ||||
| 
 | ||||
|     public void buildTree(INDArray x) { | ||||
|         this.X = x; | ||||
|         for(int i = 0; i < x.rows(); i++) { | ||||
|             root.getIndices().add(i); | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|         RPUtils.buildTree(this,root,rpHyperPlanes, | ||||
|                 x,maxSize,0,similarityFunction); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     public void addNodeAtIndex(int idx,INDArray toAdd) { | ||||
|         RPNode query = RPUtils.query(root,rpHyperPlanes,toAdd,similarityFunction); | ||||
|         query.getIndices().add(idx); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public List<RPNode> getLeaves() { | ||||
|         List<RPNode> nodes = new ArrayList<>(); | ||||
|         RPUtils.scanForLeaves(nodes,getRoot()); | ||||
|         return nodes; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Query the tree, returning candidates | ||||
|      * with their distances sorted ascending | ||||
|      * @param query the query vector | ||||
|      * @param numResults the number of results to return | ||||
|      * @return a list of samples | ||||
|      */ | ||||
|     public List<Pair<Double, Integer>> queryWithDistances(INDArray query, int numResults) { | ||||
|             return RPUtils.queryAllWithDistances(query,X,Arrays.asList(this),numResults,similarityFunction); | ||||
|     } | ||||
| 
 | ||||
|     public INDArray query(INDArray query,int numResults) { | ||||
|         return RPUtils.queryAll(query,X,Arrays.asList(this),numResults,similarityFunction); | ||||
|     } | ||||
| 
 | ||||
|     public List<Integer> getCandidates(INDArray target) { | ||||
|         return RPUtils.getCandidates(target,Arrays.asList(this),similarityFunction); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
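| A single tree can also be used without the forest wrapper; a hedged sketch with illustrative sizes: | ||||
|     INDArray data = Nd4j.rand(500, 8); | ||||
|     RPTree tree = new RPTree(8, 50);              // dim 8, at most 50 indices per leaf, euclidean default | ||||
|     tree.buildTree(data); | ||||
|     INDArray nn = tree.query(data.slice(0), 10);  // ~10 approximate nearest rows | ||||
|     List<RPNode> leaves = tree.getLeaves();       // every leaf bucket in the tree | ||||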
| @ -1,481 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.randomprojection; | ||||
| 
 | ||||
| import org.nd4j.shade.guava.primitives.Doubles; | ||||
| import lombok.val; | ||||
| import org.nd4j.autodiff.functions.DifferentialFunction; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.ops.ReduceOp; | ||||
| import org.nd4j.linalg.api.ops.impl.reduce3.*; | ||||
| import org.nd4j.linalg.exception.ND4JIllegalArgumentException; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.common.primitives.Pair; | ||||
| 
 | ||||
| import java.util.*; | ||||
| 
 | ||||
| public class RPUtils { | ||||
| 
 | ||||
| 
 | ||||
|     private static ThreadLocal<Map<String,DifferentialFunction>> functionInstances = new ThreadLocal<>(); | ||||
| 
 | ||||
|     public static <T extends DifferentialFunction> DifferentialFunction getOp(String name, | ||||
|                                                                               INDArray x, | ||||
|                                                                               INDArray y, | ||||
|                                                                               INDArray result) { | ||||
|         Map<String,DifferentialFunction> ops = functionInstances.get(); | ||||
|         if(ops == null) { | ||||
|             ops = new HashMap<>(); | ||||
|             functionInstances.set(ops); | ||||
|         } | ||||
| 
 | ||||
|         boolean allDistances = x.length() != y.length(); | ||||
| 
 | ||||
|         switch(name) { | ||||
|             case "cosinedistance": | ||||
|                 if(!ops.containsKey(name) || ((CosineDistance)ops.get(name)).isComplexAccumulation() != allDistances) { | ||||
|                     CosineDistance cosineDistance = new CosineDistance(x,y,result,allDistances); | ||||
|                     ops.put(name,cosineDistance); | ||||
|                     return cosineDistance; | ||||
|                 } | ||||
|                 else { | ||||
|                     CosineDistance cosineDistance = (CosineDistance) ops.get(name); | ||||
|                     return cosineDistance; | ||||
|                 } | ||||
|             case "cosinesimilarity": | ||||
|                 if(!ops.containsKey(name) || ((CosineSimilarity)ops.get(name)).isComplexAccumulation() != allDistances) { | ||||
|                     CosineSimilarity cosineSimilarity = new CosineSimilarity(x,y,result,allDistances); | ||||
|                     ops.put(name,cosineSimilarity); | ||||
|                     return cosineSimilarity; | ||||
|                 } | ||||
|                 else { | ||||
|                     CosineSimilarity cosineSimilarity = (CosineSimilarity) ops.get(name); | ||||
|                     cosineSimilarity.setX(x); | ||||
|                     cosineSimilarity.setY(y); | ||||
|                     cosineSimilarity.setZ(result); | ||||
|                     return cosineSimilarity; | ||||
| 
 | ||||
|                 } | ||||
|             case "manhattan": | ||||
|                 if(!ops.containsKey(name) || ((ManhattanDistance)ops.get(name)).isComplexAccumulation() != allDistances) { | ||||
|                     ManhattanDistance manhattanDistance = new ManhattanDistance(x,y,result,allDistances); | ||||
|                     ops.put(name,manhattanDistance); | ||||
|                     return manhattanDistance; | ||||
|                 } | ||||
|                 else { | ||||
|                     ManhattanDistance manhattanDistance = (ManhattanDistance) ops.get(name); | ||||
|                     manhattanDistance.setX(x); | ||||
|                     manhattanDistance.setY(y); | ||||
|                     manhattanDistance.setZ(result); | ||||
|                     return  manhattanDistance; | ||||
|                 } | ||||
|             case "jaccard": | ||||
|                 if(!ops.containsKey(name) || ((JaccardDistance)ops.get(name)).isComplexAccumulation() != allDistances) { | ||||
|                     JaccardDistance jaccardDistance = new JaccardDistance(x,y,result,allDistances); | ||||
|                     ops.put(name,jaccardDistance); | ||||
|                     return jaccardDistance; | ||||
|                 } | ||||
|                 else { | ||||
|                     JaccardDistance jaccardDistance = (JaccardDistance) ops.get(name); | ||||
|                     jaccardDistance.setX(x); | ||||
|                     jaccardDistance.setY(y); | ||||
|                     jaccardDistance.setZ(result); | ||||
|                     return jaccardDistance; | ||||
|                 } | ||||
|             case "hamming": | ||||
|                 if(!ops.containsKey(name) || ((HammingDistance)ops.get(name)).isComplexAccumulation() != allDistances) { | ||||
|                     HammingDistance hammingDistance = new HammingDistance(x,y,result,allDistances); | ||||
|                     ops.put(name,hammingDistance); | ||||
|                     return hammingDistance; | ||||
|                 } | ||||
|                 else { | ||||
|                     HammingDistance hammingDistance = (HammingDistance) ops.get(name); | ||||
|                     hammingDistance.setX(x); | ||||
|                     hammingDistance.setY(y); | ||||
|                     hammingDistance.setZ(result); | ||||
|                     return hammingDistance; | ||||
|                 } | ||||
|                 //euclidean | ||||
|             default: | ||||
|                 if(!ops.containsKey(name) || ((EuclideanDistance)ops.get(name)).isComplexAccumulation() != allDistances) { | ||||
|                     EuclideanDistance euclideanDistance = new EuclideanDistance(x,y,result,allDistances); | ||||
|                     ops.put(name,euclideanDistance); | ||||
|                     return euclideanDistance; | ||||
|                 } | ||||
|                 else { | ||||
|                     EuclideanDistance euclideanDistance = (EuclideanDistance) ops.get(name); | ||||
|                     euclideanDistance.setX(x); | ||||
|                     euclideanDistance.setY(y); | ||||
|                     euclideanDistance.setZ(result); | ||||
|                     return euclideanDistance; | ||||
|                 } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Query all trees using the given input and data | ||||
|      * @param toQuery the query vector | ||||
|      * @param X the input data to query | ||||
|      * @param trees the trees to query | ||||
|      * @param n the number of results to search for | ||||
|      * @param similarityFunction the similarity function to use | ||||
|      * @return the top n (distance, index) pairs, sorted ascending by distance | ||||
|      */ | ||||
|     public static List<Pair<Double,Integer>> queryAllWithDistances(INDArray toQuery,INDArray X,List<RPTree> trees,int n,String similarityFunction) { | ||||
|         if(trees.isEmpty()) { | ||||
|             throw new ND4JIllegalArgumentException("Trees is empty!"); | ||||
|         } | ||||
| 
 | ||||
|         List<Integer> candidates = getCandidates(toQuery, trees,similarityFunction); | ||||
|         val sortedCandidates = sortCandidates(toQuery,X,candidates,similarityFunction); | ||||
|         int numReturns = Math.min(n,sortedCandidates.size()); | ||||
|         List<Pair<Double,Integer>> ret = new ArrayList<>(numReturns); | ||||
|         for(int i = 0; i < numReturns; i++) { | ||||
|             ret.add(sortedCandidates.get(i)); | ||||
|         } | ||||
| 
 | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Query all trees using the given input and data | ||||
|      * @param toQuery the query vector | ||||
|      * @param X the input data to query | ||||
|      * @param trees the trees to query | ||||
|      * @param n the number of results to search for | ||||
|      * @param similarityFunction the similarity function to use | ||||
|      * @return the indices (in order) in the ndarray | ||||
|      */ | ||||
|     public static INDArray queryAll(INDArray toQuery,INDArray X,List<RPTree> trees,int n,String similarityFunction) { | ||||
|         if(trees.isEmpty()) { | ||||
|             throw new ND4JIllegalArgumentException("Trees is empty!"); | ||||
|         } | ||||
| 
 | ||||
|         List<Integer> candidates = getCandidates(toQuery, trees,similarityFunction); | ||||
|         val sortedCandidates = sortCandidates(toQuery,X,candidates,similarityFunction); | ||||
|         int numReturns = Math.min(n,sortedCandidates.size()); | ||||
| 
 | ||||
|         INDArray result = Nd4j.create(numReturns); | ||||
|         for(int i = 0; i < numReturns; i++) { | ||||
|             result.putScalar(i,sortedCandidates.get(i).getSecond()); | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
|         return result; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Get the sorted distances given the | ||||
|      * query vector, input data, given the list of possible search candidates | ||||
|      * @param x the query vector | ||||
|      * @param X the input data to use | ||||
|      * @param candidates the possible search candidates | ||||
|      * @param similarityFunction the similarity function to use | ||||
|      * @return the sorted distances | ||||
|      */ | ||||
|     public static List<Pair<Double,Integer>> sortCandidates(INDArray x,INDArray X, | ||||
|                                                             List<Integer> candidates, | ||||
|                                                             String similarityFunction) { | ||||
|         int prevIdx = -1; | ||||
|         List<Pair<Double,Integer>> ret = new ArrayList<>(); | ||||
|         for(int i = 0; i < candidates.size(); i++) { | ||||
|             if(candidates.get(i) != prevIdx) { | ||||
|                 ret.add(Pair.of(computeDistance(similarityFunction,X.slice(candidates.get(i)),x),candidates.get(i))); | ||||
|             } | ||||
| 
 | ||||
|             prevIdx = candidates.get(i); // track the last candidate value so consecutive duplicates are skipped | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
|         Collections.sort(ret, new Comparator<Pair<Double, Integer>>() { | ||||
|             @Override | ||||
|             public int compare(Pair<Double, Integer> doubleIntegerPair, Pair<Double, Integer> t1) { | ||||
|                 return Doubles.compare(doubleIntegerPair.getFirst(),t1.getFirst()); | ||||
|             } | ||||
|         }); | ||||
| 
 | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Get the search candidates as indices given the input | ||||
|      * and similarity function | ||||
|      * @param x the input data to search with | ||||
|      * @param trees the trees to search | ||||
|      * @param similarityFunction the function to use for similarity | ||||
|      * @return the list of indices as the search results | ||||
|      */ | ||||
|     public static INDArray getAllCandidates(INDArray x,List<RPTree> trees,String similarityFunction) { | ||||
|         List<Integer> candidates = getCandidates(x,trees,similarityFunction); | ||||
|         Collections.sort(candidates); | ||||
| 
 | ||||
|         int prevIdx = -1; | ||||
|         int idxCount = 0; | ||||
|         List<Pair<Integer,Integer>> scores = new ArrayList<>(); | ||||
|         for(int i = 0; i < candidates.size(); i++) { | ||||
|             if(candidates.get(i) == prevIdx) { | ||||
|                 idxCount++; | ||||
|             } | ||||
|             else { | ||||
|                 //start of a new run: flush the previous one (if any) and reset the count | ||||
|                 if(prevIdx != -1) | ||||
|                     scores.add(Pair.of(idxCount,prevIdx)); | ||||
|                 idxCount = 1; | ||||
|             } | ||||
| 
 | ||||
|             prevIdx = candidates.get(i); // track the last candidate value, not the loop counter | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
|         scores.add(Pair.of(idxCount,prevIdx)); | ||||
| 
 | ||||
|         INDArray arr = Nd4j.create(scores.size()); | ||||
|         for(int i = 0; i < scores.size(); i++) { | ||||
|             arr.putScalar(i,scores.get(i).getSecond()); | ||||
|         } | ||||
| 
 | ||||
|         return arr; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Get the search candidates as indices given the input | ||||
|      * and similarity function | ||||
|      * @param x the input data to search with | ||||
|      * @param roots the trees to search | ||||
|      * @param similarityFunction the function to use for similarity | ||||
|      * @return the list of indices as the search results | ||||
|      */ | ||||
|     public static List<Integer> getCandidates(INDArray x,List<RPTree> roots,String similarityFunction) { | ||||
|         Set<Integer> ret = new LinkedHashSet<>(); | ||||
|         for(RPTree tree : roots) { | ||||
|             RPNode root = tree.getRoot(); | ||||
|             RPNode query = query(root,tree.getRpHyperPlanes(),x,similarityFunction); | ||||
|             ret.addAll(query.getIndices()); | ||||
|         } | ||||
| 
 | ||||
|         return new ArrayList<>(ret); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Query the tree starting from the given node | ||||
|      * using the given hyper plane and similarity function | ||||
|      * @param from the node to start from | ||||
|      * @param planes the hyper planes used to route the query | ||||
|      * @param x the input data | ||||
|      * @param similarityFunction the similarity function to use | ||||
|      * @return the leaf node representing the given query from a | ||||
|      * search in the tree | ||||
|      */ | ||||
|     public static  RPNode query(RPNode from,RPHyperPlanes planes,INDArray x,String similarityFunction) { | ||||
|         if(from.getLeft() == null && from.getRight() == null) { | ||||
|             return from; | ||||
|         } | ||||
| 
 | ||||
|         INDArray hyperPlane = planes.getHyperPlaneAt(from.getDepth()); | ||||
|         double dist = computeDistance(similarityFunction,x,hyperPlane); | ||||
|         if(dist <= from.getMedian()) { | ||||
|             return query(from.getLeft(),planes,x,similarityFunction); | ||||
|         } | ||||
| 
 | ||||
|         else { | ||||
|             return query(from.getRight(),planes,x,similarityFunction); | ||||
|         } | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Compute the distances between 2 matrices along dimension 1 | ||||
|      * given a function name. Valid function names: | ||||
|      * euclidean: euclidean distance | ||||
|      * cosinedistance: cosine distance | ||||
|      * cosinesimilarity: cosine similarity | ||||
|      * manhattan: manhattan distance | ||||
|      * jaccard: jaccard distance | ||||
|      * hamming: hamming distance | ||||
|      * @param function the function to use (default euclidean distance) | ||||
|      * @param x the first matrix | ||||
|      * @param y the second matrix | ||||
|      * @param result the buffer the reduction writes into | ||||
|      * @return a vector of distances, one per row | ||||
|      */ | ||||
|     public static INDArray computeDistanceMulti(String function,INDArray x,INDArray y,INDArray result) { | ||||
|         ReduceOp op = (ReduceOp) getOp(function, x, y, result); | ||||
|         op.setDimensions(1); | ||||
|         Nd4j.getExecutioner().exec(op); | ||||
|         return op.z(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Compute the distance between 2 vectors | ||||
|      * given a function name. Valid function names: | ||||
|      * euclidean: euclidean distance | ||||
|      * cosinedistance: cosine distance | ||||
|      * cosinesimilarity: cosine similarity | ||||
|      * manhattan: manhattan distance | ||||
|      * jaccard: jaccard distance | ||||
|      * hamming: hamming distance | ||||
|      * @param function the function to use (default euclidean distance) | ||||
|      * @param x the first vector | ||||
|      * @param y the second vector | ||||
|      * @param result the buffer the scalar reduction writes into | ||||
|      * @return the distance between the 2 vectors given the inputs | ||||
|      */ | ||||
|     public static double computeDistance(String function,INDArray x,INDArray y,INDArray result) { | ||||
|         ReduceOp op = (ReduceOp) getOp(function, x, y, result); | ||||
|         Nd4j.getExecutioner().exec(op); | ||||
|         return op.z().getDouble(0); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Compute the distance between 2 vectors | ||||
|      * given a function name. Valid function names: | ||||
|      * euclidean: euclidean distance | ||||
|      * cosinedistance: cosine distance | ||||
|      * cosinesimilarity: cosine similarity | ||||
|      * manhattan: manhattan distance | ||||
|      * jaccard: jaccard distance | ||||
|      * hamming: hamming distance | ||||
|      * @param function the function to use (default euclidean distance) | ||||
|      * @param x the first vector | ||||
|      * @param y the second vector | ||||
|      * @return the distance between the 2 vectors given the inputs | ||||
|      */ | ||||
|     public static double computeDistance(String function,INDArray x,INDArray y) { | ||||
|         return computeDistance(function,x,y,Nd4j.scalar(0.0)); | ||||
|     } | ||||
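| A hedged sketch of the dispatch above: unrecognized names fall through to Euclidean distance, and this overload allocates a fresh scalar buffer per call: | ||||
|     INDArray a = Nd4j.create(new double[] {1, 0, 0}); | ||||
|     INDArray b = Nd4j.create(new double[] {0, 1, 0}); | ||||
|     double euclidean = RPUtils.computeDistance("euclidean", a, b);     // sqrt(2), about 1.414 | ||||
|     double cosDist = RPUtils.computeDistance("cosinedistance", a, b);  // 1.0 for orthogonal vectors | ||||
|     double fallback = RPUtils.computeDistance("anything-else", a, b);  // treated as euclidean | ||||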
| 
 | ||||
|     /** | ||||
|      * Initialize the tree given the input parameters | ||||
|      * @param tree the tree to initialize | ||||
|      * @param from the starting node | ||||
|      * @param planes the hyper planes to use (vector space for similarity) | ||||
|      * @param X the input data | ||||
|      * @param maxSize the max number of indices on a given leaf node | ||||
|      * @param depth the current depth of the tree | ||||
|      * @param similarityFunction the similarity function to use | ||||
|      */ | ||||
|     public static void buildTree(RPTree tree, | ||||
|                                  RPNode from, | ||||
|                                  RPHyperPlanes planes, | ||||
|                                  INDArray X, | ||||
|                                  int maxSize, | ||||
|                                  int depth, | ||||
|                                  String similarityFunction) { | ||||
|         if(from.getIndices().size() <= maxSize) { | ||||
|             //small enough to be a leaf: prune and stop | ||||
|             slimNode(from); | ||||
|             return; | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
|         List<Double> distances = new ArrayList<>(); | ||||
|         RPNode left = new RPNode(tree,depth + 1); | ||||
|         RPNode right = new RPNode(tree,depth + 1); | ||||
| 
 | ||||
|         if(planes.getWholeHyperPlane() == null || depth >= planes.getWholeHyperPlane().rows()) { | ||||
|             planes.addRandomHyperPlane(); | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
|         INDArray hyperPlane = planes.getHyperPlaneAt(depth); | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|         for(int i = 0; i < from.getIndices().size(); i++) { | ||||
|             double dist = computeDistance(similarityFunction,hyperPlane,X.slice(from.getIndices().get(i))); | ||||
|             distances.add(dist); | ||||
|         } | ||||
| 
 | ||||
|         Collections.sort(distances); | ||||
|         from.setMedian(distances.get(distances.size() / 2)); | ||||
| 
 | ||||
| 
 | ||||
|         for(int i = 0; i < from.getIndices().size(); i++) { | ||||
|             double dist = computeDistance(similarityFunction,hyperPlane,X.slice(from.getIndices().get(i))); | ||||
|             if(dist <= from.getMedian()) { | ||||
|                 left.getIndices().add(from.getIndices().get(i)); | ||||
|             } | ||||
|             else { | ||||
|                 right.getIndices().add(from.getIndices().get(i)); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         //failed split | ||||
|         if(left.getIndices().isEmpty() || right.getIndices().isEmpty()) { | ||||
|             slimNode(from); | ||||
|             return; | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
|         from.setLeft(left); | ||||
|         from.setRight(right); | ||||
|         slimNode(from); | ||||
| 
 | ||||
| 
 | ||||
|         buildTree(tree,left,planes,X,maxSize,depth + 1,similarityFunction); | ||||
|         buildTree(tree,right,planes,X,maxSize,depth + 1,similarityFunction); | ||||
| 
 | ||||
|     } | ||||
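| The split rule: project every index in the node onto this depth's hyperplane, take the median projection, and route indices with projection <= median left and the rest right. A hedged scalar illustration with made-up projections: | ||||
|     double[] projections = {0.9, -0.3, 0.1, 0.7, -0.8}; | ||||
|     double[] sorted = projections.clone(); | ||||
|     java.util.Arrays.sort(sorted);              // {-0.8, -0.3, 0.1, 0.7, 0.9} | ||||
|     double median = sorted[sorted.length / 2];  // 0.1, the same middle-element rule as above | ||||
|     for (double p : projections) { | ||||
|         System.out.println(p + (p <= median ? " -> left" : " -> right")); | ||||
|     } | ||||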
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Scan for leaves accumulating | ||||
|      * the nodes in the passed in list | ||||
|      * @param nodes the nodes so far | ||||
|      * @param scan the tree to scan | ||||
|      */ | ||||
|     public static void scanForLeaves(List<RPNode> nodes,RPTree scan) { | ||||
|         scanForLeaves(nodes,scan.getRoot()); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Scan for leaves accumulating | ||||
|      * the nodes in the passed in list | ||||
|      * @param nodes the nodes so far | ||||
|      * @param current the node to scan from | ||||
|      */ | ||||
|     public static void scanForLeaves(List<RPNode> nodes,RPNode current) { | ||||
|         if(current.getLeft() == null && current.getRight() == null) | ||||
|             nodes.add(current); | ||||
|         if(current.getLeft() != null) | ||||
|             scanForLeaves(nodes,current.getLeft()); | ||||
|         if(current.getRight() != null) | ||||
|             scanForLeaves(nodes,current.getRight()); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Clear the stored indices from the given node | ||||
|      * once it has both children (i.e. is no longer a leaf) | ||||
|      * @param node the node to prune | ||||
|      */ | ||||
|     public static void slimNode(RPNode node) { | ||||
|         if(node.getRight() != null && node.getLeft() != null) { | ||||
|             node.getIndices().clear(); | ||||
|         } | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,87 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.sptree; | ||||
| 
 | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| 
 | ||||
| /** | ||||
|  * @author Adam Gibson | ||||
|  */ | ||||
| public class Cell implements Serializable { | ||||
|     private int dimension; | ||||
|     private INDArray corner, width; | ||||
| 
 | ||||
|     public Cell(int dimension) { | ||||
|         this.dimension = dimension; | ||||
|     } | ||||
| 
 | ||||
|     public double corner(int d) { | ||||
|         return corner.getDouble(d); | ||||
|     } | ||||
| 
 | ||||
|     public double width(int d) { | ||||
|         return width.getDouble(d); | ||||
|     } | ||||
| 
 | ||||
|     public void setCorner(int d, double corner) { | ||||
|         this.corner.putScalar(d, corner); | ||||
|     } | ||||
| 
 | ||||
|     public void setWidth(int d, double width) { | ||||
|         this.width.putScalar(d, width); | ||||
|     } | ||||
| 
 | ||||
|     public void setWidth(INDArray width) { | ||||
|         this.width = width; | ||||
|     } | ||||
| 
 | ||||
|     public void setCorner(INDArray corner) { | ||||
|         this.corner = corner; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public boolean contains(INDArray point) { | ||||
|         INDArray cornerMinusWidth = corner.sub(width); | ||||
|         INDArray cornerPlusWidth = corner.add(width); | ||||
|         for (int d = 0; d < dimension; d++) { | ||||
|             double pointD = point.getDouble(d); | ||||
|             if (cornerMinusWidth.getDouble(d) > pointD) | ||||
|                 return false; | ||||
|             if (cornerPlusWidth.getDouble(d) < pointD) | ||||
|                 return false; | ||||
|         } | ||||
|         return true; | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     public INDArray width() { | ||||
|         return width; | ||||
|     } | ||||
| 
 | ||||
|     public INDArray corner() { | ||||
|         return corner; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,95 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.sptree; | ||||
| 
 | ||||
| import lombok.Data; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity; | ||||
| import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance; | ||||
| import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| 
 | ||||
| @Data | ||||
| public class DataPoint implements Serializable { | ||||
|     private int index; | ||||
|     private INDArray point; | ||||
|     private long d; | ||||
|     private String functionName; | ||||
|     private boolean invert = false; | ||||
| 
 | ||||
| 
 | ||||
|     public DataPoint(int index, INDArray point, boolean invert) { | ||||
|         this(index, point, "euclidean"); | ||||
|         this.invert = invert; | ||||
|     } | ||||
| 
 | ||||
|     public DataPoint(int index, INDArray point, String functionName, boolean invert) { | ||||
|         this.index = index; | ||||
|         this.point = point; | ||||
|         this.functionName = functionName; | ||||
|         this.d = point.length(); | ||||
|         this.invert = invert; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public DataPoint(int index, INDArray point) { | ||||
|         this(index, point, false); | ||||
|     } | ||||
| 
 | ||||
|     public DataPoint(int index, INDArray point, String functionName) { | ||||
|         this(index, point, functionName, false); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Distance from this point to the given point under the | ||||
|      * configured function (euclidean by default), negated when invert is set | ||||
|      * @param point the point to measure the distance to | ||||
|      * @return the distance between the two points | ||||
|      */ | ||||
|     public float distance(DataPoint point) { | ||||
|         switch (functionName) { | ||||
|             case "euclidean": | ||||
|                 float ret = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this.point, point.point)) | ||||
|                                 .getFinalResult().floatValue(); | ||||
|                 return invert ? -ret : ret; | ||||
| 
 | ||||
|             case "cosinesimilarity": | ||||
|                 float ret2 = Nd4j.getExecutioner().execAndReturn(new CosineSimilarity(this.point, point.point)) | ||||
|                                 .getFinalResult().floatValue(); | ||||
|                 return invert ? -ret2 : ret2; | ||||
| 
 | ||||
|             case "manhattan": | ||||
|                 float ret3 = Nd4j.getExecutioner().execAndReturn(new ManhattanDistance(this.point, point.point)) | ||||
|                                 .getFinalResult().floatValue(); | ||||
|                 return invert ? -ret3 : ret3; | ||||
|             case "dot": | ||||
|                 float dotRet = (float) Nd4j.getBlasWrapper().dot(this.point, point.point); | ||||
|                 return invert ? -dotRet : dotRet; | ||||
|             default: | ||||
|                 float ret4 = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this.point, point.point)) | ||||
|                                 .getFinalResult().floatValue(); | ||||
|                 return invert ? -ret4 : ret4; | ||||
| 
 | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| } | ||||
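| A hedged sketch; note that invert simply negates the result, which lets a similarity (dot, cosine) stand in where smaller-is-closer is expected: | ||||
|     DataPoint p0 = new DataPoint(0, Nd4j.create(new double[] {0, 0})); | ||||
|     DataPoint p1 = new DataPoint(1, Nd4j.create(new double[] {3, 4})); | ||||
|     float d = p0.distance(p1);   // 5.0 under the default "euclidean" function | ||||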
| @ -1,83 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.sptree; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| 
 | ||||
| /** | ||||
|  * @author Adam Gibson | ||||
|  */ | ||||
| public class HeapItem implements Serializable, Comparable<HeapItem> { | ||||
|     private int index; | ||||
|     private double distance; | ||||
| 
 | ||||
| 
 | ||||
|     public HeapItem(int index, double distance) { | ||||
|         this.index = index; | ||||
|         this.distance = distance; | ||||
|     } | ||||
| 
 | ||||
|     public int getIndex() { | ||||
|         return index; | ||||
|     } | ||||
| 
 | ||||
|     public void setIndex(int index) { | ||||
|         this.index = index; | ||||
|     } | ||||
| 
 | ||||
|     public double getDistance() { | ||||
|         return distance; | ||||
|     } | ||||
| 
 | ||||
|     public void setDistance(double distance) { | ||||
|         this.distance = distance; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public boolean equals(Object o) { | ||||
|         if (this == o) | ||||
|             return true; | ||||
|         if (o == null || getClass() != o.getClass()) | ||||
|             return false; | ||||
| 
 | ||||
|         HeapItem heapItem = (HeapItem) o; | ||||
| 
 | ||||
|         if (index != heapItem.index) | ||||
|             return false; | ||||
|         return Double.compare(heapItem.distance, distance) == 0; | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public int hashCode() { | ||||
|         int result; | ||||
|         long temp; | ||||
|         result = index; | ||||
|         temp = Double.doubleToLongBits(distance); | ||||
|         result = 31 * result + (int) (temp ^ (temp >>> 32)); | ||||
|         return result; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public int compareTo(HeapItem o) { | ||||
|         // Descending by distance. The original expression | ||||
|         // (distance < o.distance ? 1 : 0) never returned -1, violating the | ||||
|         // compareTo contract and breaking sorted collections | ||||
|         return Double.compare(o.distance, distance); | ||||
|     } | ||||
| } | ||||
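A small sketch of the ordering the comparator above yields when HeapItem is used with a JDK priority queue (values illustrative):

    import java.util.PriorityQueue;

    public class HeapItemDemo {
        public static void main(String[] args) {
            PriorityQueue<HeapItem> queue = new PriorityQueue<>();
            queue.add(new HeapItem(0, 2.5));
            queue.add(new HeapItem(1, 0.5));
            queue.add(new HeapItem(2, 1.5));
            // Descending by distance: the farthest candidate is polled first,
            // the order in which a k-nearest-neighbour search evicts items
            System.out.println(queue.poll().getIndex()); // 0 (distance 2.5)
        }
    }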
| @ -1,72 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.sptree; | ||||
| 
 | ||||
| import lombok.Data; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| 
 | ||||
| @Data | ||||
| public class HeapObject implements Serializable, Comparable<HeapObject> { | ||||
|     private int index; | ||||
|     private INDArray point; | ||||
|     private double distance; | ||||
| 
 | ||||
| 
 | ||||
|     public HeapObject(int index, INDArray point, double distance) { | ||||
|         this.index = index; | ||||
|         this.point = point; | ||||
|         this.distance = distance; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Override | ||||
|     public boolean equals(Object o) { | ||||
|         if (this == o) | ||||
|             return true; | ||||
|         if (o == null || getClass() != o.getClass()) | ||||
|             return false; | ||||
| 
 | ||||
|         HeapObject heapObject = (HeapObject) o; | ||||
| 
 | ||||
|         if (!point.equals(heapObject.point)) | ||||
|             return false; | ||||
| 
 | ||||
|         return Double.compare(heapObject.distance, distance) == 0; | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public int hashCode() { | ||||
|         // equals() compares point and distance (not index), so hashCode() | ||||
|         // is derived from the same fields to keep the contract consistent | ||||
|         int result = point != null ? point.hashCode() : 0; | ||||
|         long temp = Double.doubleToLongBits(distance); | ||||
|         result = 31 * result + (int) (temp ^ (temp >>> 32)); | ||||
|         return result; | ||||
|     } | ||||
| 
 | ||||
|     @Override | ||||
|     public int compareTo(HeapObject o) { | ||||
|         // Same contract fix as HeapItem: order descending by distance | ||||
|         return Double.compare(o.distance, distance); | ||||
|     } | ||||
| } | ||||
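Because equality is defined by point and distance rather than index, two HeapObject instances wrapping the same coordinates collapse to one entry in hash-based collections. A sketch (assumes nd4j):

    import java.util.HashSet;
    import java.util.Set;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class HeapObjectDemo {
        public static void main(String[] args) {
            INDArray p = Nd4j.create(new float[] {1f, 2f});
            Set<HeapObject> seen = new HashSet<>();
            seen.add(new HeapObject(0, p, 0.5));
            seen.add(new HeapObject(7, p, 0.5)); // different index, same point/distance
            System.out.println(seen.size()); // 1: index is ignored by equals()
        }
    }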
| @ -1,425 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.sptree; | ||||
| 
 | ||||
| import org.nd4j.shade.guava.util.concurrent.AtomicDouble; | ||||
| import lombok.val; | ||||
| import org.deeplearning4j.clustering.algorithm.Distance; | ||||
| import org.deeplearning4j.nn.conf.WorkspaceMode; | ||||
| import org.nd4j.linalg.api.memory.MemoryWorkspace; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace; | ||||
| import org.slf4j.Logger; | ||||
| import org.slf4j.LoggerFactory; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| import java.util.ArrayList; | ||||
| import java.util.Collection; | ||||
| import java.util.Set; | ||||
| 
 | ||||
| 
 | ||||
| /** | ||||
|  * @author Adam Gibson | ||||
|  */ | ||||
| public class SpTree implements Serializable { | ||||
| 
 | ||||
| 
 | ||||
|     public final static String workspaceExternal = "SPTREE_LOOP_EXTERNAL"; | ||||
| 
 | ||||
| 
 | ||||
|     private int D; | ||||
|     private INDArray data; | ||||
|     public final static int NODE_RATIO = 8000; | ||||
|     private int N; | ||||
|     private int size; | ||||
|     private int cumSize; | ||||
|     private Cell boundary; | ||||
|     private INDArray centerOfMass; | ||||
|     private SpTree parent; | ||||
|     private int[] index; | ||||
|     private int nodeCapacity; | ||||
|     private int numChildren = 2; | ||||
|     private boolean isLeaf = true; | ||||
|     private Collection<INDArray> indices; | ||||
|     private SpTree[] children; | ||||
|     private static Logger log = LoggerFactory.getLogger(SpTree.class); | ||||
|     private String similarityFunction = Distance.EUCLIDEAN.toString(); | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices, | ||||
|                   String similarityFunction) { | ||||
|         init(parent, data, corner, width, indices, similarityFunction); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public SpTree(INDArray data, Collection<INDArray> indices, String similarityFunction) { | ||||
|         this.indices = indices; | ||||
|         this.N = data.rows(); | ||||
|         this.D = data.columns(); | ||||
|         this.similarityFunction = similarityFunction; | ||||
|         data = data.dup(); | ||||
|         INDArray meanY = data.mean(0); | ||||
|         INDArray minY = data.min(0); | ||||
|         INDArray maxY = data.max(0); | ||||
|         INDArray width = Nd4j.create(data.dataType(), meanY.shape()); | ||||
|         for (int i = 0; i < width.length(); i++) { | ||||
|             width.putScalar(i, Math.max(maxY.getDouble(i) - meanY.getDouble(i), | ||||
|                     meanY.getDouble(i) - minY.getDouble(i)) + Nd4j.EPS_THRESHOLD); | ||||
|         } | ||||
| 
 | ||||
|         try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { | ||||
|             init(null, data, meanY, width, indices, similarityFunction); | ||||
|             fill(N); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices) { | ||||
|         this(parent, data, corner, width, indices, "euclidean"); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public SpTree(INDArray data, Collection<INDArray> indices) { | ||||
|         this(data, indices, "euclidean"); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     public SpTree(INDArray data) { | ||||
|         this(data, new ArrayList<INDArray>()); | ||||
|     } | ||||
| 
 | ||||
|     public MemoryWorkspace workspace() { | ||||
|         return null; | ||||
|     } | ||||
| 
 | ||||
|     private void init(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices, | ||||
|                       String similarityFunction) { | ||||
| 
 | ||||
|         this.parent = parent; | ||||
|         D = data.columns(); | ||||
|         N = data.rows(); | ||||
|         this.similarityFunction = similarityFunction; | ||||
|         nodeCapacity = N % NODE_RATIO; | ||||
|         index = new int[nodeCapacity]; | ||||
|         for (int d = 1; d < this.D; d++) | ||||
|             numChildren *= 2; | ||||
|         this.indices = indices; | ||||
|         isLeaf = true; | ||||
|         size = 0; | ||||
|         cumSize = 0; | ||||
|         children = new SpTree[numChildren]; | ||||
|         this.data = data; | ||||
|         boundary = new Cell(D); | ||||
|         boundary.setCorner(corner.dup()); | ||||
|         boundary.setWidth(width.dup()); | ||||
|         centerOfMass = Nd4j.create(data.dataType(), D); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     private boolean insert(int index) { | ||||
|         /*MemoryWorkspace workspace = | ||||
|                 workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace() | ||||
|                         : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread( | ||||
|                         workspaceConfigurationExternal, | ||||
|                         workspaceExternal); | ||||
|         try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ { | ||||
| 
 | ||||
|             INDArray point = data.slice(index); | ||||
|             /*boolean contains = false; | ||||
|             SpTreeCell op = new SpTreeCell(boundary.corner(), boundary.width(), point, N, contains); | ||||
|             Nd4j.getExecutioner().exec(op); | ||||
|             op.getOutputArgument(0).getScalar(0); | ||||
|             if (!contains) return false;*/ | ||||
|             if (!boundary.contains(point)) | ||||
|                 return false; | ||||
| 
 | ||||
| 
 | ||||
|             cumSize++; | ||||
|             double mult1 = (double) (cumSize - 1) / (double) cumSize; | ||||
|             double mult2 = 1.0 / (double) cumSize; | ||||
|             centerOfMass.muli(mult1); | ||||
|             centerOfMass.addi(point.mul(mult2)); | ||||
|             // If there is space in this node and it is a leaf, add the object here | ||||
|             if (isLeaf() && size < nodeCapacity) { | ||||
|                 this.index[size] = index; | ||||
|                 indices.add(point); | ||||
|                 size++; | ||||
|                 return true; | ||||
|             } | ||||
| 
 | ||||
| 
 | ||||
|             for (int i = 0; i < size; i++) { | ||||
|                 INDArray compPoint = data.slice(this.index[i]); | ||||
|                 if (compPoint.equals(point)) | ||||
|                     return true; | ||||
|             } | ||||
| 
 | ||||
| 
 | ||||
|             if (isLeaf()) | ||||
|                 subDivide(); | ||||
| 
 | ||||
| 
 | ||||
|             // Find out where the point can be inserted | ||||
|             for (int i = 0; i < numChildren; i++) { | ||||
|                 if (children[i].insert(index)) | ||||
|                     return true; | ||||
|             } | ||||
| 
 | ||||
|             throw new IllegalStateException("Shouldn't reach this state"); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Subdivide the node into 2^d children, | ||||
|      * where d is the dimensionality of the data | ||||
|      */ | ||||
|     public void subDivide() { | ||||
|         /*MemoryWorkspace workspace = | ||||
|                 workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace() | ||||
|                         : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread( | ||||
|                         workspaceConfigurationExternal, | ||||
|                         workspaceExternal); | ||||
|         try (MemoryWorkspace ws = workspace.notifyScopeEntered()) */{ | ||||
| 
 | ||||
|             INDArray newCorner = Nd4j.create(data.dataType(), D); | ||||
|             INDArray newWidth = Nd4j.create(data.dataType(), D); | ||||
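|             // Each child index i encodes an orthant: (i / div) % 2 with | ||||
|             // div = 2^d tests bit d of i, choosing the lower (bit set) or | ||||
|             // upper (bit unset) half of the parent cell along dimension d | ||||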
|             for (int i = 0; i < numChildren; i++) { | ||||
|                 int div = 1; | ||||
|                 for (int d = 0; d < D; d++) { | ||||
|                     newWidth.putScalar(d, .5 * boundary.width(d)); | ||||
|                     if ((i / div) % 2 == 1) | ||||
|                         newCorner.putScalar(d, boundary.corner(d) - .5 * boundary.width(d)); | ||||
|                     else | ||||
|                         newCorner.putScalar(d, boundary.corner(d) + .5 * boundary.width(d)); | ||||
|                     div *= 2; | ||||
|                 } | ||||
| 
 | ||||
|                 children[i] = new SpTree(this, data, newCorner, newWidth, indices); | ||||
| 
 | ||||
|             } | ||||
| 
 | ||||
|             // Move existing points to correct children | ||||
|             for (int i = 0; i < size; i++) { | ||||
|                 boolean success = false; | ||||
|                 for (int j = 0; j < this.numChildren; j++) | ||||
|                     if (!success) | ||||
|                         success = children[j].insert(index[i]); | ||||
| 
 | ||||
|                 index[i] = -1; | ||||
|             } | ||||
| 
 | ||||
|             // Empty parent node | ||||
|             size = 0; | ||||
|             isLeaf = false; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Compute the non-edge (repulsive) forces using the Barnes-Hut approximation | ||||
|      * @param pointIndex the index of the point to compute forces for | ||||
|      * @param theta the Barnes-Hut accuracy/speed trade-off parameter | ||||
|      * @param negativeForce the array the negative force is accumulated into | ||||
|      * @param sumQ the running sum of Q values (normalization term) | ||||
|      */ | ||||
|     public void computeNonEdgeForces(int pointIndex, double theta, INDArray negativeForce, AtomicDouble sumQ) { | ||||
|         // Make sure that we spend no time on empty nodes or self-interactions | ||||
|         INDArray buf = Nd4j.create(data.dataType(), this.D); | ||||
| 
 | ||||
|         if (cumSize == 0 || (isLeaf() && size == 1 && index[0] == pointIndex)) | ||||
|             return; | ||||
|        /* MemoryWorkspace workspace = | ||||
|                 workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace() | ||||
|                         : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread( | ||||
|                         workspaceConfigurationExternal, | ||||
|                         workspaceExternal); | ||||
|         try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ { | ||||
| 
 | ||||
|             // Compute distance between point and center-of-mass | ||||
|             data.slice(pointIndex).subi(centerOfMass, buf); | ||||
| 
 | ||||
|             double D = Nd4j.getBlasWrapper().dot(buf, buf); | ||||
|             double maxWidth = boundary.width().maxNumber().doubleValue(); | ||||
|             // Check whether we can use this node as a "summary" | ||||
|             if (isLeaf() || maxWidth / Math.sqrt(D) < theta) { | ||||
| 
 | ||||
|                 // Compute and add t-SNE force between point and current node | ||||
|                 double Q = 1.0 / (1.0 + D); | ||||
|                 double mult = cumSize * Q; | ||||
|                 sumQ.addAndGet(mult); | ||||
|                 mult *= Q; | ||||
|                 negativeForce.addi(buf.mul(mult)); | ||||
|             } else { | ||||
| 
 | ||||
|                 // Recursively apply Barnes-Hut to children | ||||
|                 for (int i = 0; i < numChildren; i++) { | ||||
|                     children[i].computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ); | ||||
|                 } | ||||
| 
 | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Compute the edge (attractive) forces using Barnes-Hut | ||||
|      * @param rowP row pointers of the sparse affinity matrix (must be a vector) | ||||
|      * @param colP column indices of the sparse affinity matrix | ||||
|      * @param valP values of the sparse affinity matrix | ||||
|      * @param N the number of elements | ||||
|      * @param posF the array the positive forces are accumulated into | ||||
|      */ | ||||
|     public void computeEdgeForces(INDArray rowP, INDArray colP, INDArray valP, int N, INDArray posF) { | ||||
|         if (!rowP.isVector()) | ||||
|             throw new IllegalArgumentException("RowP must be a vector"); | ||||
| 
 | ||||
|         // Loop over all edges in the graph | ||||
|         // just execute native op | ||||
|         Nd4j.exec(new BarnesEdgeForces(rowP, colP, valP, data, N, posF)); | ||||
| 
 | ||||
|         /* | ||||
|         INDArray buf = Nd4j.create(data.dataType(), this.D); | ||||
|         double D; | ||||
|         for (int n = 0; n < N; n++) { | ||||
|             INDArray slice = data.slice(n); | ||||
|             for (int i = rowP.getInt(n); i < rowP.getInt(n + 1); i++) { | ||||
| 
 | ||||
|                 // Compute pairwise distance and Q-value | ||||
|                 slice.subi(data.slice(colP.getInt(i)), buf); | ||||
| 
 | ||||
|                 D = 1.0 + Nd4j.getBlasWrapper().dot(buf, buf); | ||||
|                 D = valP.getDouble(i) / D; | ||||
| 
 | ||||
|                 // Sum positive force | ||||
|                 posF.slice(n).addi(buf.muli(D)); | ||||
|             } | ||||
|         } | ||||
|         */ | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     public boolean isLeaf() { | ||||
|         return isLeaf; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Verifies the structure of the tree (does bounds checking on each node) | ||||
|      * @return true if the structure of the tree | ||||
|      * is correct. | ||||
|      */ | ||||
|     public boolean isCorrect() { | ||||
|         /*MemoryWorkspace workspace = | ||||
|                 workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace() | ||||
|                         : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread( | ||||
|                         workspaceConfigurationExternal, | ||||
|                         workspaceExternal); | ||||
|         try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ { | ||||
| 
 | ||||
|             for (int n = 0; n < size; n++) { | ||||
|                 INDArray point = data.slice(index[n]); | ||||
|                 if (!boundary.contains(point)) | ||||
|                     return false; | ||||
|             } | ||||
|             if (!isLeaf()) { | ||||
|                 boolean correct = true; | ||||
|                 for (int i = 0; i < numChildren; i++) | ||||
|                     correct = correct && children[i].isCorrect(); | ||||
|                 return correct; | ||||
|             } | ||||
| 
 | ||||
|             return true; | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * The depth of the node | ||||
|      * @return the depth of the node | ||||
|      */ | ||||
|     public int depth() { | ||||
|         if (isLeaf()) | ||||
|             return 1; | ||||
|         int depth = 1; | ||||
|         int maxChildDepth = 0; | ||||
|         for (int i = 0; i < numChildren; i++) { | ||||
|             // was children[0].depth(), which ignored all but the first child | ||||
|             maxChildDepth = Math.max(maxChildDepth, children[i].depth()); | ||||
|         } | ||||
| 
 | ||||
|         return depth + maxChildDepth; | ||||
|     } | ||||
| 
 | ||||
|     private void fill(int n) { | ||||
|         if (indices.isEmpty() && parent == null) | ||||
|             for (int i = 0; i < n; i++) { | ||||
|                 log.trace("Inserted {}", i); | ||||
|                 insert(i); | ||||
|             } | ||||
|         else | ||||
|             log.warn("fill() should only be called once, on the root node"); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public SpTree[] getChildren() { | ||||
|         return children; | ||||
|     } | ||||
| 
 | ||||
|     public int getD() { | ||||
|         return D; | ||||
|     } | ||||
| 
 | ||||
|     public INDArray getCenterOfMass() { | ||||
|         return centerOfMass; | ||||
|     } | ||||
| 
 | ||||
|     public Cell getBoundary() { | ||||
|         return boundary; | ||||
|     } | ||||
| 
 | ||||
|     public int[] getIndex() { | ||||
|         return index; | ||||
|     } | ||||
| 
 | ||||
|     public int getCumSize() { | ||||
|         return cumSize; | ||||
|     } | ||||
| 
 | ||||
|     public void setCumSize(int cumSize) { | ||||
|         this.cumSize = cumSize; | ||||
|     } | ||||
| 
 | ||||
|     public int getNumChildren() { | ||||
|         return numChildren; | ||||
|     } | ||||
| 
 | ||||
|     public void setNumChildren(int numChildren) { | ||||
|         this.numChildren = numChildren; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
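For context, a minimal sketch of how this removed SpTree was driven during Barnes-Hut gradient computation (assumes nd4j; the data is random and purely illustrative):

    import org.deeplearning4j.clustering.sptree.SpTree;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.shade.guava.util.concurrent.AtomicDouble;

    public class SpTreeDemo {
        public static void main(String[] args) {
            INDArray data = Nd4j.rand(100, 2); // 100 points in 2-D
            SpTree tree = new SpTree(data);

            // Repulsive (non-edge) forces on point 0; theta = 0.5 is a common
            // Barnes-Hut accuracy/speed trade-off
            INDArray negativeForce = Nd4j.create(data.dataType(), 2);
            AtomicDouble sumQ = new AtomicDouble(0.0);
            tree.computeNonEdgeForces(0, 0.5, negativeForce, sumQ);

            System.out.println("depth = " + tree.depth());
            System.out.println("structure ok = " + tree.isCorrect());
        }
    }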
| @ -1,117 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.strategy; | ||||
| 
 | ||||
| import lombok.*; | ||||
| import org.deeplearning4j.clustering.algorithm.Distance; | ||||
| import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition; | ||||
| import org.deeplearning4j.clustering.condition.ConvergenceCondition; | ||||
| import org.deeplearning4j.clustering.condition.FixedIterationCountCondition; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| 
 | ||||
| @AllArgsConstructor(access = AccessLevel.PROTECTED) | ||||
| @NoArgsConstructor(access = AccessLevel.PROTECTED) | ||||
| public abstract class BaseClusteringStrategy implements ClusteringStrategy, Serializable { | ||||
|     @Getter(AccessLevel.PUBLIC) | ||||
|     @Setter(AccessLevel.PROTECTED) | ||||
|     protected ClusteringStrategyType type; | ||||
|     @Getter(AccessLevel.PUBLIC) | ||||
|     @Setter(AccessLevel.PROTECTED) | ||||
|     protected Integer initialClusterCount; | ||||
|     @Getter(AccessLevel.PUBLIC) | ||||
|     @Setter(AccessLevel.PROTECTED) | ||||
|     protected ClusteringAlgorithmCondition optimizationPhaseCondition; | ||||
|     @Getter(AccessLevel.PUBLIC) | ||||
|     @Setter(AccessLevel.PROTECTED) | ||||
|     protected ClusteringAlgorithmCondition terminationCondition; | ||||
|     @Getter(AccessLevel.PUBLIC) | ||||
|     @Setter(AccessLevel.PROTECTED) | ||||
|     protected boolean inverse; | ||||
|     @Getter(AccessLevel.PUBLIC) | ||||
|     @Setter(AccessLevel.PROTECTED) | ||||
|     protected Distance distanceFunction; | ||||
|     @Getter(AccessLevel.PUBLIC) | ||||
|     @Setter(AccessLevel.PROTECTED) | ||||
|     protected boolean allowEmptyClusters; | ||||
| 
 | ||||
|     public BaseClusteringStrategy(ClusteringStrategyType type, Integer initialClusterCount, Distance distanceFunction, | ||||
|                     boolean allowEmptyClusters, boolean inverse) { | ||||
|         this.type = type; | ||||
|         this.initialClusterCount = initialClusterCount; | ||||
|         this.distanceFunction = distanceFunction; | ||||
|         this.allowEmptyClusters = allowEmptyClusters; | ||||
|         this.inverse = inverse; | ||||
|     } | ||||
| 
 | ||||
|     public BaseClusteringStrategy(ClusteringStrategyType clusteringStrategyType, int initialClusterCount, | ||||
|                     Distance distanceFunction, boolean inverse) { | ||||
|         this(clusteringStrategyType, initialClusterCount, distanceFunction, false, inverse); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Terminate clustering once the iteration count reaches the given maximum | ||||
|      * @param maxIterationCount the maximum number of iterations | ||||
|      * @return this strategy, for chaining | ||||
|      */ | ||||
|     public BaseClusteringStrategy endWhenIterationCountEquals(int maxIterationCount) { | ||||
|         setTerminationCondition(FixedIterationCountCondition.iterationCountGreaterThan(maxIterationCount)); | ||||
|         return this; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Terminate clustering once the distribution variation rate drops below the given rate | ||||
|      * @param rate the convergence threshold | ||||
|      * @return this strategy, for chaining | ||||
|      */ | ||||
|     public BaseClusteringStrategy endWhenDistributionVariationRateLessThan(double rate) { | ||||
|         setTerminationCondition(ConvergenceCondition.distributionVariationRateLessThan(rate)); | ||||
|         return this; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @return true if distances are inverted (similarity maximized rather than distance minimized) | ||||
|      */ | ||||
|     @Override | ||||
|     public boolean inverseDistanceCalculation() { | ||||
|         return inverse; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Whether this strategy is of the given type | ||||
|      * @param type the strategy type to test against | ||||
|      * @return true if the types match | ||||
|      */ | ||||
|     public boolean isStrategyOfType(ClusteringStrategyType type) { | ||||
|         return type.equals(this.type); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @return the number of clusters to start with | ||||
|      */ | ||||
|     public Integer getInitialClusterCount() { | ||||
|         return initialClusterCount; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,102 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.strategy; | ||||
| 
 | ||||
| import org.deeplearning4j.clustering.algorithm.Distance; | ||||
| import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition; | ||||
| import org.deeplearning4j.clustering.iteration.IterationHistory; | ||||
| 
 | ||||
| /** | ||||
|  * Configuration contract for a clustering run: distance function, initial | ||||
|  * cluster count, and the termination / optimization conditions | ||||
|  */ | ||||
| public interface ClusteringStrategy { | ||||
| 
 | ||||
|     /** | ||||
|      * @return true if distances are inverted (similarity maximized) | ||||
|      */ | ||||
|     boolean inverseDistanceCalculation(); | ||||
| 
 | ||||
|     /** | ||||
|      * @return the type of this strategy | ||||
|      */ | ||||
|     ClusteringStrategyType getType(); | ||||
| 
 | ||||
|     /** | ||||
|      * Whether this strategy is of the given type | ||||
|      * @param type the strategy type to test against | ||||
|      * @return true if the types match | ||||
|      */ | ||||
|     boolean isStrategyOfType(ClusteringStrategyType type); | ||||
| 
 | ||||
|     /** | ||||
|      * @return the number of clusters to start with | ||||
|      */ | ||||
|     Integer getInitialClusterCount(); | ||||
| 
 | ||||
|     /** | ||||
|      * @return the distance function in use | ||||
|      */ | ||||
|     Distance getDistanceFunction(); | ||||
| 
 | ||||
|     /** | ||||
|      * @return true if empty clusters are permitted | ||||
|      */ | ||||
|     boolean isAllowEmptyClusters(); | ||||
| 
 | ||||
|     /** | ||||
|      * @return the condition that terminates the clustering run | ||||
|      */ | ||||
|     ClusteringAlgorithmCondition getTerminationCondition(); | ||||
| 
 | ||||
|     /** | ||||
|      * @return true if an optimization phase has been configured | ||||
|      */ | ||||
|     boolean isOptimizationDefined(); | ||||
| 
 | ||||
|     /** | ||||
|      * Whether the optimization phase should run at this point | ||||
|      * @param iterationHistory the history of iterations so far | ||||
|      * @return true if the optimization condition is currently satisfied | ||||
|      */ | ||||
|     boolean isOptimizationApplicableNow(IterationHistory iterationHistory); | ||||
| 
 | ||||
|     /** | ||||
|      * Terminate once the iteration count reaches the given maximum | ||||
|      * @param maxIterationCount the maximum number of iterations | ||||
|      * @return this strategy, for chaining | ||||
|      */ | ||||
|     BaseClusteringStrategy endWhenIterationCountEquals(int maxIterationCount); | ||||
| 
 | ||||
|     /** | ||||
|      * Terminate once the distribution variation rate drops below the given rate | ||||
|      * @param rate the convergence threshold | ||||
|      * @return this strategy, for chaining | ||||
|      */ | ||||
|     BaseClusteringStrategy endWhenDistributionVariationRateLessThan(double rate); | ||||
| 
 | ||||
| } | ||||
| @ -1,25 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.strategy; | ||||
| 
 | ||||
| public enum ClusteringStrategyType { | ||||
|     FIXED_CLUSTER_COUNT, OPTIMIZATION | ||||
| } | ||||
| @ -1,68 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.strategy; | ||||
| 
 | ||||
| import lombok.AccessLevel; | ||||
| import lombok.NoArgsConstructor; | ||||
| import org.deeplearning4j.clustering.algorithm.Distance; | ||||
| import org.deeplearning4j.clustering.iteration.IterationHistory; | ||||
| 
 | ||||
| /** | ||||
|  * Clustering strategy that targets a fixed number of clusters | ||||
|  */ | ||||
| @NoArgsConstructor(access = AccessLevel.PROTECTED) | ||||
| public class FixedClusterCountStrategy extends BaseClusteringStrategy { | ||||
| 
 | ||||
| 
 | ||||
|     protected FixedClusterCountStrategy(Integer initialClusterCount, Distance distanceFunction, | ||||
|                     boolean allowEmptyClusters, boolean inverse) { | ||||
|         super(ClusteringStrategyType.FIXED_CLUSTER_COUNT, initialClusterCount, distanceFunction, allowEmptyClusters, | ||||
|                         inverse); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Create a strategy targeting a fixed number of clusters | ||||
|      * @param clusterCount the number of clusters to produce | ||||
|      * @param distanceFunction the distance function to use | ||||
|      * @param inverse whether to invert the distance calculation | ||||
|      * @return the configured strategy | ||||
|      */ | ||||
|     public static FixedClusterCountStrategy setup(int clusterCount, Distance distanceFunction, boolean inverse) { | ||||
|         return new FixedClusterCountStrategy(clusterCount, distanceFunction, false, inverse); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @return true if distances are inverted | ||||
|      */ | ||||
|     @Override | ||||
|     public boolean inverseDistanceCalculation() { | ||||
|         return inverse; | ||||
|     } | ||||
| 
 | ||||
|     public boolean isOptimizationDefined() { | ||||
|         return false; | ||||
|     } | ||||
| 
 | ||||
|     public boolean isOptimizationApplicableNow(IterationHistory iterationHistory) { | ||||
|         return false; | ||||
|     } | ||||
| 
 | ||||
| } | ||||
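A hypothetical configuration sketch using the fluent API above:

    import org.deeplearning4j.clustering.algorithm.Distance;
    import org.deeplearning4j.clustering.strategy.ClusteringStrategy;
    import org.deeplearning4j.clustering.strategy.FixedClusterCountStrategy;

    public class StrategyDemo {
        public static void main(String[] args) {
            // Five clusters, Euclidean distance, no inversion; stop after 10 iterations
            ClusteringStrategy strategy = FixedClusterCountStrategy
                    .setup(5, Distance.EUCLIDEAN, false)
                    .endWhenIterationCountEquals(10);
            System.out.println(strategy.getInitialClusterCount()); // 5
        }
    }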
| @ -1,82 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.strategy; | ||||
| 
 | ||||
| import org.deeplearning4j.clustering.algorithm.Distance; | ||||
| import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition; | ||||
| import org.deeplearning4j.clustering.condition.ConvergenceCondition; | ||||
| import org.deeplearning4j.clustering.condition.FixedIterationCountCondition; | ||||
| import org.deeplearning4j.clustering.iteration.IterationHistory; | ||||
| import org.deeplearning4j.clustering.optimisation.ClusteringOptimization; | ||||
| import org.deeplearning4j.clustering.optimisation.ClusteringOptimizationType; | ||||
| 
 | ||||
| public class OptimisationStrategy extends BaseClusteringStrategy { | ||||
|     public static int defaultIterationCount = 100; | ||||
| 
 | ||||
|     private ClusteringOptimization clusteringOptimisation; | ||||
|     private ClusteringAlgorithmCondition clusteringOptimisationApplicationCondition; | ||||
| 
 | ||||
|     protected OptimisationStrategy() { | ||||
|         super(); | ||||
|     } | ||||
| 
 | ||||
|     protected OptimisationStrategy(int initialClusterCount, Distance distanceFunction) { | ||||
|         super(ClusteringStrategyType.OPTIMIZATION, initialClusterCount, distanceFunction, false); | ||||
|     } | ||||
| 
 | ||||
|     public static OptimisationStrategy setup(int initialClusterCount, Distance distanceFunction) { | ||||
|         return new OptimisationStrategy(initialClusterCount, distanceFunction); | ||||
|     } | ||||
| 
 | ||||
|     public OptimisationStrategy optimize(ClusteringOptimizationType type, double value) { | ||||
|         clusteringOptimisation = new ClusteringOptimization(type, value); | ||||
|         return this; | ||||
|     } | ||||
| 
 | ||||
|     public OptimisationStrategy optimizeWhenIterationCountMultipleOf(int value) { | ||||
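|         // Note: despite the method name, the installed condition appears to be | ||||
|         // a greater-than iteration-count check (see FixedIterationCountCondition) | ||||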
|         clusteringOptimisationApplicationCondition = FixedIterationCountCondition.iterationCountGreaterThan(value); | ||||
|         return this; | ||||
|     } | ||||
| 
 | ||||
|     public OptimisationStrategy optimizeWhenPointDistributionVariationRateLessThan(double rate) { | ||||
|         clusteringOptimisationApplicationCondition = ConvergenceCondition.distributionVariationRateLessThan(rate); | ||||
|         return this; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     public double getClusteringOptimizationValue() { | ||||
|         return clusteringOptimisation.getValue(); | ||||
|     } | ||||
| 
 | ||||
|     public boolean isClusteringOptimizationType(ClusteringOptimizationType type) { | ||||
|         return clusteringOptimisation != null && clusteringOptimisation.getType().equals(type); | ||||
|     } | ||||
| 
 | ||||
|     public boolean isOptimizationDefined() { | ||||
|         return clusteringOptimisation != null; | ||||
|     } | ||||
| 
 | ||||
|     public boolean isOptimizationApplicableNow(IterationHistory iterationHistory) { | ||||
|         return clusteringOptimisationApplicationCondition != null | ||||
|                         && clusteringOptimisationApplicationCondition.isSatisfied(iterationHistory); | ||||
|     } | ||||
| 
 | ||||
| } | ||||
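A sketch of combining optimization with termination (the ClusteringOptimizationType constant below is illustrative, not verified against this build):

    import org.deeplearning4j.clustering.algorithm.Distance;
    import org.deeplearning4j.clustering.optimisation.ClusteringOptimizationType;
    import org.deeplearning4j.clustering.strategy.OptimisationStrategy;

    public class OptimisationDemo {
        public static void main(String[] args) {
            OptimisationStrategy strategy = OptimisationStrategy
                    .setup(5, Distance.EUCLIDEAN)
                    // illustrative enum constant; use whichever the enum defines
                    .optimize(ClusteringOptimizationType.MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE, 0.01)
                    .optimizeWhenIterationCountMultipleOf(5);
            strategy.endWhenDistributionVariationRateLessThan(0.05);
            System.out.println(strategy.isOptimizationDefined()); // true
        }
    }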
										
											
File diff suppressed because it is too large
											
										
									
								
							| @ -1,74 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.util; | ||||
| 
 | ||||
| import org.slf4j.Logger; | ||||
| import org.slf4j.LoggerFactory; | ||||
| 
 | ||||
| import java.util.List; | ||||
| import java.util.concurrent.*; | ||||
| 
 | ||||
| public class MultiThreadUtils { | ||||
| 
 | ||||
|     private static Logger log = LoggerFactory.getLogger(MultiThreadUtils.class); | ||||
| 
 | ||||
| 
 | ||||
|     private MultiThreadUtils() {} | ||||
| 
 | ||||
|     public static synchronized ExecutorService newExecutorService() { | ||||
|         int nThreads = Runtime.getRuntime().availableProcessors(); | ||||
|         return new ThreadPoolExecutor(nThreads, nThreads, 60L, TimeUnit.SECONDS, new LinkedTransferQueue<Runnable>(), | ||||
|                         new ThreadFactory() { | ||||
|                             @Override | ||||
|                             public Thread newThread(Runnable r) { | ||||
|                                 Thread t = Executors.defaultThreadFactory().newThread(r); | ||||
|                                 t.setDaemon(true); | ||||
|                                 return t; | ||||
|                             } | ||||
|                         }); | ||||
|     } | ||||
| 
 | ||||
|     public static void parallelTasks(final List<Runnable> tasks, ExecutorService executorService) { | ||||
|         int tasksCount = tasks.size(); | ||||
|         final CountDownLatch latch = new CountDownLatch(tasksCount); | ||||
|         for (int i = 0; i < tasksCount; i++) { | ||||
|             final int taskIdx = i; | ||||
|             executorService.execute(new Runnable() { | ||||
|                 public void run() { | ||||
|                     try { | ||||
|                         tasks.get(taskIdx).run(); | ||||
|                     } catch (Throwable e) { | ||||
|                         log.warn("Unchecked exception thrown by task", e); | ||||
|                     } finally { | ||||
|                         latch.countDown(); | ||||
|                     } | ||||
|                 } | ||||
|             }); | ||||
|         } | ||||
| 
 | ||||
|         try { | ||||
|             latch.await(); | ||||
|         } catch (Exception e) { | ||||
|             throw new RuntimeException(e); | ||||
|         } | ||||
|     } | ||||
| } | ||||
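A quick usage sketch of the helper above (task bodies are illustrative):

    import java.util.Arrays;
    import java.util.List;
    import java.util.concurrent.ExecutorService;

    public class ParallelDemo {
        public static void main(String[] args) {
            ExecutorService pool = MultiThreadUtils.newExecutorService();
            List<Runnable> tasks = Arrays.asList(
                    () -> System.out.println("task A"),
                    () -> System.out.println("task B"));
            // Blocks until both tasks complete; task failures are logged, not rethrown
            MultiThreadUtils.parallelTasks(tasks, pool);
            pool.shutdown();
        }
    }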
| @ -1,61 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.util; | ||||
| 
 | ||||
| import java.util.Collection; | ||||
| import java.util.HashSet; | ||||
| import java.util.Set; | ||||
| 
 | ||||
| public class SetUtils { | ||||
|     private SetUtils() {} | ||||
| 
 | ||||
|     // Set specific operations | ||||
| 
 | ||||
|     public static <T> Set<T> intersection(Collection<T> parentCollection, Collection<T> otherCollection) { | ||||
|         Set<T> results = new HashSet<>(parentCollection); | ||||
|         results.retainAll(otherCollection); | ||||
|         return results; | ||||
|     } | ||||
| 
 | ||||
|     public static <T> boolean intersectionP(Set<? extends T> s1, Set<? extends T> s2) { | ||||
|         for (T elt : s1) { | ||||
|             if (s2.contains(elt)) | ||||
|                 return true; | ||||
|         } | ||||
|         return false; | ||||
|     } | ||||
| 
 | ||||
|     public static <T> Set<T> union(Set<? extends T> s1, Set<? extends T> s2) { | ||||
|         Set<T> s3 = new HashSet<>(s1); | ||||
|         s3.addAll(s2); | ||||
|         return s3; | ||||
|     } | ||||
| 
 | ||||
|     /** Returns the set difference s1 \ s2 */ | ||||
| 
 | ||||
|     public static <T> Set<T> difference(Collection<? extends T> s1, Collection<? extends T> s2) { | ||||
|         Set<T> s3 = new HashSet<>(s1); | ||||
|         s3.removeAll(s2); | ||||
|         return s3; | ||||
|     } | ||||
| } | ||||
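For reference, a quick sketch of these set helpers:

    import java.util.Arrays;
    import java.util.HashSet;
    import java.util.Set;

    public class SetUtilsDemo {
        public static void main(String[] args) {
            Set<Integer> a = new HashSet<>(Arrays.asList(1, 2, 3));
            Set<Integer> b = new HashSet<>(Arrays.asList(3, 4));
            System.out.println(SetUtils.intersection(a, b));  // [3]
            System.out.println(SetUtils.union(a, b));         // [1, 2, 3, 4]
            System.out.println(SetUtils.difference(a, b));    // [1, 2]
            System.out.println(SetUtils.intersectionP(a, b)); // true
        }
    }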
| 
 | ||||
| 
 | ||||
| @ -1,633 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.vptree; | ||||
| 
 | ||||
| import lombok.*; | ||||
| import lombok.extern.slf4j.Slf4j; | ||||
| import org.deeplearning4j.clustering.sptree.DataPoint; | ||||
| import org.deeplearning4j.clustering.sptree.HeapObject; | ||||
| import org.deeplearning4j.clustering.util.MathUtils; | ||||
| import org.nd4j.linalg.api.memory.MemoryWorkspace; | ||||
| import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; | ||||
| import org.nd4j.linalg.api.memory.enums.*; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.api.ops.impl.reduce3.*; | ||||
| import org.nd4j.linalg.exception.ND4JIllegalStateException; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.io.Serializable; | ||||
| import java.util.*; | ||||
| import java.util.concurrent.*; | ||||
| import java.util.concurrent.atomic.AtomicInteger; | ||||
| 
 | ||||
| @Slf4j | ||||
| @Builder | ||||
| @AllArgsConstructor | ||||
| public class VPTree implements Serializable { | ||||
|     private static final long serialVersionUID = 1L; | ||||
| 
 | ||||
|     public static final String EUCLIDEAN = "euclidean"; | ||||
|     private double tau; | ||||
|     @Getter | ||||
|     @Setter | ||||
|     private INDArray items; | ||||
|     private List<INDArray> itemsList; | ||||
|     private Node root; | ||||
|     private String similarityFunction; | ||||
|     @Getter | ||||
|     private boolean invert = false; | ||||
|     private transient ExecutorService executorService; | ||||
|     @Getter | ||||
|     private int workers = 1; | ||||
|     private AtomicInteger size = new AtomicInteger(0); | ||||
| 
 | ||||
|     private transient ThreadLocal<INDArray> scalars = new ThreadLocal<>(); | ||||
| 
 | ||||
|     private WorkspaceConfiguration workspaceConfiguration; | ||||
| 
 | ||||
|     protected VPTree() { | ||||
|         // method for serialization only | ||||
|         scalars = new ThreadLocal<>(); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param points the points to index | ||||
|      * @param invert whether to invert the distance calculation | ||||
|      */ | ||||
|     public VPTree(INDArray points, boolean invert) { | ||||
|         this(points, "euclidean", 1, invert); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * @param points the points to index | ||||
|      * @param invert whether to invert the distance calculation | ||||
|      * @param workers number of parallel workers for tree building (increases memory requirements!) | ||||
|      */ | ||||
|     public VPTree(INDArray points, boolean invert, int workers) { | ||||
|         this(points, "euclidean", workers, invert); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * | ||||
|      * @param items the items to use | ||||
|      * @param similarityFunction the similarity function to use | ||||
|      * @param invert whether to invert the distance (similarity functions have different min/max objectives) | ||||
|      */ | ||||
|     public VPTree(INDArray items, String similarityFunction, boolean invert) { | ||||
|         this.similarityFunction = similarityFunction; | ||||
|         this.invert = invert; | ||||
|         this.items = items; | ||||
|         root = buildFromPoints(items); | ||||
|         workers = 1; | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * | ||||
|      * @param items the items to use | ||||
|      * @param similarityFunction the similarity function to use | ||||
|      * @param workers number of parallel workers for tree building (increases memory requirements!) | ||||
|      * @param invert whether to invert the metric (different optimization objective) | ||||
|      */ | ||||
|     public VPTree(List<DataPoint> items, String similarityFunction, int workers, boolean invert) { | ||||
|         this.workers = workers; | ||||
| 
 | ||||
|         val list = new INDArray[items.size()]; | ||||
| 
 | ||||
|         // build list of INDArrays first | ||||
|         for (int i = 0; i < items.size(); i++) | ||||
|             list[i] = items.get(i).getPoint(); | ||||
|             //this.items.putRow(i, items.get(i).getPoint()); | ||||
| 
 | ||||
|         // stack them into a single matrix with Nd4j.pile | ||||
|         this.items = Nd4j.pile(list); | ||||
| 
 | ||||
|         this.invert = invert; | ||||
|         this.similarityFunction = similarityFunction; | ||||
|         root = buildFromPoints(this.items); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * | ||||
|      * @param items the items to index | ||||
|      * @param similarityFunction the similarity function to use | ||||
|      */ | ||||
|     public VPTree(INDArray items, String similarityFunction) { | ||||
|         this(items, similarityFunction, 1, false); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * | ||||
|      * @param items the items to index | ||||
|      * @param similarityFunction the similarity function to use | ||||
|      * @param workers number of parallel workers for tree building (increases memory requirements!) | ||||
|      * @param invert whether to invert the distance calculation | ||||
|      */ | ||||
|     public VPTree(INDArray items, String similarityFunction, int workers, boolean invert) { | ||||
|         this.similarityFunction = similarityFunction; | ||||
|         this.invert = invert; | ||||
|         this.items = items; | ||||
| 
 | ||||
|         this.workers = workers; | ||||
|         root = buildFromPoints(items); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * | ||||
|      * @param items the items to index | ||||
|      * @param similarityFunction the similarity function to use | ||||
|      */ | ||||
|     public VPTree(List<DataPoint> items, String similarityFunction) { | ||||
|         this(items, similarityFunction, 1, false); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * @param items the items to index | ||||
|      */ | ||||
|     public VPTree(INDArray items) { | ||||
|         this(items, EUCLIDEAN); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * @param items the data points to index | ||||
|      */ | ||||
|     public VPTree(List<DataPoint> items) { | ||||
|         this(items, EUCLIDEAN); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Stack the given data points into a single ndarray | ||||
|      * @param data the data points to stack | ||||
|      * @return a matrix with one slice per data point | ||||
|      */ | ||||
|     public static INDArray buildFromData(List<DataPoint> data) { | ||||
|         INDArray ret = Nd4j.create(data.size(), data.get(0).getD()); | ||||
|         for (int i = 0; i < ret.slices(); i++) | ||||
|             ret.putSlice(i, data.get(i).getPoint()); | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Compute the distance from every row of items to the base point | ||||
|      * @param items the matrix of points (one per row) | ||||
|      * @param basePoint the point to measure distances from | ||||
|      * @param distancesArr the array the distances are written into | ||||
|      */ | ||||
|     public void calcDistancesRelativeTo(INDArray items, INDArray basePoint, INDArray distancesArr) { | ||||
|         switch (similarityFunction) { | ||||
|             case "euclidean": | ||||
|                 Nd4j.getExecutioner().exec(new EuclideanDistance(items, basePoint, distancesArr, true,-1)); | ||||
|                 break; | ||||
|             case "cosinedistance": | ||||
|                 Nd4j.getExecutioner().exec(new CosineDistance(items, basePoint, distancesArr, true, -1)); | ||||
|                 break; | ||||
|             case "cosinesimilarity": | ||||
|                 Nd4j.getExecutioner().exec(new CosineSimilarity(items, basePoint, distancesArr, true, -1)); | ||||
|                 break; | ||||
|             case "manhattan": | ||||
|                 Nd4j.getExecutioner().exec(new ManhattanDistance(items, basePoint, distancesArr, true, -1)); | ||||
|                 break; | ||||
|             case "dot": | ||||
|                 Nd4j.getExecutioner().exec(new Dot(items, basePoint, distancesArr, -1)); | ||||
|                 break; | ||||
|             case "jaccard": | ||||
|                 Nd4j.getExecutioner().exec(new JaccardDistance(items, basePoint, distancesArr, true, -1)); | ||||
|                 break; | ||||
|             case "hamming": | ||||
|                 Nd4j.getExecutioner().exec(new HammingDistance(items, basePoint, distancesArr, true, -1)); | ||||
|                 break; | ||||
|             default: | ||||
|                 Nd4j.getExecutioner().exec(new EuclideanDistance(items, basePoint, distancesArr, true, -1)); | ||||
|                 break; | ||||
| 
 | ||||
|         } | ||||
| 
 | ||||
|         if (invert) | ||||
|             distancesArr.negi(); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     public void calcDistancesRelativeTo(INDArray basePoint, INDArray distancesArr) { | ||||
|         calcDistancesRelativeTo(items, basePoint, distancesArr); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     /** | ||||
|      * Distance between the two arrays under the configured similarity function | ||||
|      * @return the distance between the two points (negated when invert is set) | ||||
|      */ | ||||
|     public double distance(INDArray arr1, INDArray arr2) { | ||||
|         if (scalars == null) | ||||
|             scalars = new ThreadLocal<>(); | ||||
| 
 | ||||
|         if (scalars.get() == null) | ||||
|             scalars.set(Nd4j.scalar(arr1.dataType(), 0.0)); | ||||
| 
 | ||||
|         switch (similarityFunction) { | ||||
|             case "jaccard": | ||||
|                 double ret7 = Nd4j.getExecutioner() | ||||
|                         .execAndReturn(new JaccardDistance(arr1, arr2, scalars.get())) | ||||
|                         .getFinalResult().doubleValue(); | ||||
|                 return invert ? -ret7 : ret7; | ||||
|             case "hamming": | ||||
|                 double ret8 = Nd4j.getExecutioner() | ||||
|                         .execAndReturn(new HammingDistance(arr1, arr2, scalars.get())) | ||||
|                         .getFinalResult().doubleValue(); | ||||
|                 return invert ? -ret8 : ret8; | ||||
|             case "euclidean": | ||||
|                 double ret = Nd4j.getExecutioner() | ||||
|                         .execAndReturn(new EuclideanDistance(arr1, arr2, scalars.get())) | ||||
|                         .getFinalResult().doubleValue(); | ||||
|                 return invert ? -ret : ret; | ||||
|             case "cosinesimilarity": | ||||
|                 double ret2 = Nd4j.getExecutioner() | ||||
|                         .execAndReturn(new CosineSimilarity(arr1, arr2, scalars.get())) | ||||
|                         .getFinalResult().doubleValue(); | ||||
|                 return invert ? -ret2 : ret2; | ||||
|             case "cosinedistance": | ||||
|                 double ret6 = Nd4j.getExecutioner() | ||||
|                         .execAndReturn(new CosineDistance(arr1, arr2, scalars.get())) | ||||
|                         .getFinalResult().doubleValue(); | ||||
|                 return invert ? -ret6 : ret6; | ||||
|             case "manhattan": | ||||
|                 double ret3 = Nd4j.getExecutioner() | ||||
|                         .execAndReturn(new ManhattanDistance(arr1, arr2, scalars.get())) | ||||
|                         .getFinalResult().doubleValue(); | ||||
|                 return invert ? -ret3 : ret3; | ||||
|             case "dot": | ||||
|                 double dotRet = Nd4j.getBlasWrapper().dot(arr1, arr2); | ||||
|                 return invert ? -dotRet : dotRet; | ||||
|             default: | ||||
|                 double ret4 = Nd4j.getExecutioner() | ||||
|                         .execAndReturn(new EuclideanDistance(arr1, arr2, scalars.get())) | ||||
|                         .getFinalResult().doubleValue(); | ||||
|                 return invert ? -ret4 : ret4; | ||||
| 
 | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
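|     // Callable wrapper so the left and right subtrees can be built in | ||||
|     // parallel on the executor when more than one worker is configured | ||||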
|     protected class NodeBuilder implements Callable<Node> { | ||||
|         protected List<INDArray> list; | ||||
|         protected List<Integer> indices; | ||||
| 
 | ||||
|         public NodeBuilder(List<INDArray> list, List<Integer> indices) { | ||||
|             this.list = list; | ||||
|             this.indices = indices; | ||||
|         } | ||||
| 
 | ||||
|         @Override | ||||
|         public Node call() throws Exception { | ||||
|             return buildFromPoints(list, indices); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     private Node buildFromPoints(List<INDArray> points, List<Integer> indices) { | ||||
|         Node ret = new Node(0, 0); | ||||
| 
 | ||||
| 
 | ||||
|         // nothing to sort here | ||||
|         if (points.size() == 1) { | ||||
|             ret.point = points.get(0); | ||||
|             ret.index = indices.get(0); | ||||
|             return ret; | ||||
|         } | ||||
| 
 | ||||
|         // opening workspace, and creating it if that's the first call | ||||
|        /* MemoryWorkspace workspace = | ||||
|                 Nd4j.getWorkspaceManager().getAndActivateWorkspace(workspaceConfiguration, "VPTREE_WORKSPACE");*/ | ||||
| 
 | ||||
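|         // stack the candidate points into one matrix so all distances to | ||||
|         // the vantage point can be computed in a single operation | ||||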
|         INDArray items = Nd4j.vstack(points); | ||||
|         int randomPoint = MathUtils.randomNumberBetween(0, items.rows() - 1, Nd4j.getRandom()); | ||||
|         INDArray basePoint = points.get(randomPoint);//items.getRow(randomPoint); | ||||
|         ret.point = basePoint; | ||||
|         ret.index = indices.get(randomPoint); | ||||
|         INDArray distancesArr = Nd4j.create(items.rows(), 1); | ||||
| 
 | ||||
|         calcDistancesRelativeTo(items, basePoint, distancesArr); | ||||
| 
 | ||||
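|         // the median distance to the vantage point becomes the ball radius: | ||||
|         // points inside it go to the left child, the rest to the right | ||||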
|         double medianDistance = distancesArr.medianNumber().doubleValue(); | ||||
| 
 | ||||
|         ret.threshold = (float) medianDistance; | ||||
| 
 | ||||
|         List<INDArray> leftPoints = new ArrayList<>(); | ||||
|         List<Integer> leftIndices = new ArrayList<>(); | ||||
|         List<INDArray> rightPoints = new ArrayList<>(); | ||||
|         List<Integer> rightIndices = new ArrayList<>(); | ||||
| 
 | ||||
|         for (int i = 0; i < distancesArr.length(); i++) { | ||||
|             if (i == randomPoint) | ||||
|                 continue; | ||||
| 
 | ||||
|             if (distancesArr.getDouble(i) < medianDistance) { | ||||
|                 leftPoints.add(points.get(i)); | ||||
|                 leftIndices.add(indices.get(i)); | ||||
|             } else { | ||||
|                 rightPoints.add(points.get(i)); | ||||
|                 rightIndices.add(indices.get(i)); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         // closing workspace | ||||
|         //workspace.notifyScopeLeft(); | ||||
|         //log.info("Thread: {}; Workspace size: {} MB; ConstantCache: {}; ShapeCache: {}; TADCache: {}", Thread.currentThread().getId(), (int) (workspace.getCurrentSize() / 1024 / 1024 ), Nd4j.getConstantHandler().getCachedBytes(), Nd4j.getShapeInfoProvider().getCachedBytes(), Nd4j.getExecutioner().getTADManager().getCachedBytes()); | ||||
| 
 | ||||
|         if (workers > 1) { | ||||
|             if (!leftPoints.isEmpty()) | ||||
|                 ret.futureLeft = executorService.submit(new NodeBuilder(leftPoints, leftIndices)); // = buildFromPoints(leftPoints); | ||||
| 
 | ||||
|             if (!rightPoints.isEmpty()) | ||||
|                 ret.futureRight = executorService.submit(new NodeBuilder(rightPoints, rightIndices)); | ||||
|         } else { | ||||
|             if (!leftPoints.isEmpty()) | ||||
|                 ret.left = buildFromPoints(leftPoints, leftIndices); | ||||
| 
 | ||||
|             if (!rightPoints.isEmpty()) | ||||
|                 ret.right = buildFromPoints(rightPoints, rightIndices); | ||||
|         } | ||||
| 
 | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
|     private Node buildFromPoints(INDArray items) { | ||||
|         if (executorService == null && items == this.items && workers > 1) { | ||||
|             final val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); | ||||
| 
 | ||||
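|             // pin each worker thread to the calling thread's device so all | ||||
|             // parallel subtree builds run on the same backend device | ||||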
|             executorService = Executors.newFixedThreadPool(workers, new ThreadFactory() { | ||||
|                 @Override | ||||
|                 public Thread newThread(final Runnable r) { | ||||
|                     Thread t = new Thread(new Runnable() { | ||||
| 
 | ||||
|                         @Override | ||||
|                         public void run() { | ||||
|                             Nd4j.getAffinityManager().unsafeSetDevice(deviceId); | ||||
|                             r.run(); | ||||
|                         } | ||||
|                     }); | ||||
| 
 | ||||
|                     t.setDaemon(true); | ||||
|                     t.setName("VPTree thread"); | ||||
| 
 | ||||
|                     return t; | ||||
|                 } | ||||
|             }); | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
|         final Node ret = new Node(0, 0); | ||||
|         size.incrementAndGet(); | ||||
| 
 | ||||
|         /*workspaceConfiguration = WorkspaceConfiguration.builder().cyclesBeforeInitialization(1) | ||||
|                 .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.FIRST_LOOP) | ||||
|                 .policyMirroring(MirroringPolicy.FULL).policyReset(ResetPolicy.BLOCK_LEFT) | ||||
|                 .policySpill(SpillPolicy.REALLOCATE).build(); | ||||
| 
 | ||||
|         // opening workspace | ||||
|         MemoryWorkspace workspace = | ||||
|                 Nd4j.getWorkspaceManager().getAndActivateWorkspace(workspaceConfiguration, "VPTREE_WORKSPACE");*/ | ||||
| 
 | ||||
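|         // pick a random row as the vantage point for this subtree | ||||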
|         int randomPoint = MathUtils.randomNumberBetween(0, items.rows() - 1, Nd4j.getRandom()); | ||||
|         INDArray basePoint = items.getRow(randomPoint, true); | ||||
|         INDArray distancesArr = Nd4j.create(items.rows(), 1); | ||||
|         ret.point = basePoint; | ||||
|         ret.index = randomPoint; | ||||
| 
 | ||||
|         calcDistancesRelativeTo(items, basePoint, distancesArr); | ||||
| 
 | ||||
|         double medianDistance = distancesArr.medianNumber().doubleValue(); | ||||
| 
 | ||||
|         ret.threshold = (float) medianDistance; | ||||
| 
 | ||||
|         List<INDArray> leftPoints = new ArrayList<>(); | ||||
|         List<Integer> leftIndices = new ArrayList<>(); | ||||
|         List<INDArray> rightPoints = new ArrayList<>(); | ||||
|         List<Integer> rightIndices = new ArrayList<>(); | ||||
| 
 | ||||
|         for (int i = 0; i < distancesArr.length(); i++) { | ||||
|             if (i == randomPoint) | ||||
|                 continue; | ||||
| 
 | ||||
|             if (distancesArr.getDouble(i) < medianDistance) { | ||||
|                 leftPoints.add(items.getRow(i, true)); | ||||
|                 leftIndices.add(i); | ||||
|             } else { | ||||
|                 rightPoints.add(items.getRow(i, true)); | ||||
|                 rightIndices.add(i); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|         // closing workspace | ||||
|         //workspace.notifyScopeLeft(); | ||||
|         //workspace.destroyWorkspace(true); | ||||
| 
 | ||||
|         if (!leftPoints.isEmpty()) | ||||
|             ret.left = buildFromPoints(leftPoints, leftIndices); | ||||
| 
 | ||||
|         if (!rightPoints.isEmpty()) | ||||
|             ret.right = buildFromPoints(rightPoints, rightIndices); | ||||
| 
 | ||||
|         // destroy once again | ||||
|         //workspace.destroyWorkspace(true); | ||||
| 
 | ||||
|         if (ret.left != null) | ||||
|             ret.left.fetchFutures(); | ||||
| 
 | ||||
|         if (ret.right != null) | ||||
|             ret.right.fetchFutures(); | ||||
| 
 | ||||
|         if (executorService != null) | ||||
|             executorService.shutdown(); | ||||
| 
 | ||||
|         return ret; | ||||
|     } | ||||
| 
 | ||||
|     public void search(@NonNull INDArray target, int k, List<DataPoint> results, List<Double> distances) { | ||||
|         search(target, k, results, distances, true); | ||||
|     } | ||||
| 
 | ||||
|     public void search(@NonNull INDArray target, int k, List<DataPoint> results, List<Double> distances, | ||||
|                        boolean filterEqual) { | ||||
|         search(target, k, results, distances, filterEqual, false); | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Search for the k nearest neighbors of a target point. | ||||
|      * | ||||
|      * @param target the point to search around | ||||
|      * @param k the number of nearest neighbors to return | ||||
|      * @param results populated with the nearest points, closest first | ||||
|      * @param distances populated with the distance of each result to the target | ||||
|      * @param filterEqual if true, a result at distance 0 (the target itself) is dropped | ||||
|      * @param dropEdge if true, the result trimming step runs even when no surplus candidates were collected | ||||
|      */ | ||||
|     public void search(@NonNull INDArray target, int k, List<DataPoint> results, List<Double> distances, | ||||
|                        boolean filterEqual, boolean dropEdge) { | ||||
|         if (items != null) | ||||
|             if (!target.isVectorOrScalar() || target.columns() != items.columns() || target.rows() > 1) | ||||
|                 throw new ND4JIllegalStateException("Target for search should have shape of [1, " | ||||
|                         + items.columns() + "] but got " + Arrays.toString(target.shape()) + " instead"); | ||||
| 
 | ||||
|         k = Math.min(k, items.rows()); | ||||
|         results.clear(); | ||||
|         distances.clear(); | ||||
| 
 | ||||
|         PriorityQueue<HeapObject> pq = new PriorityQueue<>(items.rows(), new HeapObjectComparator()); | ||||
| 
 | ||||
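|         // over-collect by one (two when filtering equal points) so the query | ||||
|         // point itself can be dropped below and k results still remain | ||||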
|         search(root, target, k + (filterEqual ? 2 : 1), pq, Double.MAX_VALUE); | ||||
| 
 | ||||
|         while (!pq.isEmpty()) { | ||||
|             HeapObject ho = pq.peek(); | ||||
|             results.add(new DataPoint(ho.getIndex(), ho.getPoint())); | ||||
|             distances.add(ho.getDistance()); | ||||
|             pq.poll(); | ||||
|         } | ||||
| 
 | ||||
|         Collections.reverse(results); | ||||
|         Collections.reverse(distances); | ||||
| 
 | ||||
|         if (dropEdge || results.size() > k) { | ||||
|             if (filterEqual && distances.get(0) == 0.0) { | ||||
|                 results.remove(0); | ||||
|                 distances.remove(0); | ||||
|             } | ||||
| 
 | ||||
|             while (results.size() > k) { | ||||
|                 results.remove(results.size() - 1); | ||||
|                 distances.remove(distances.size() - 1); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     /** | ||||
|      * Recursively search a subtree for nearest neighbors of the target. | ||||
|      * | ||||
|      * @param node subtree root to search (may be null) | ||||
|      * @param target the query point | ||||
|      * @param k maximum number of candidates kept in the heap | ||||
|      * @param pq max-heap of the best candidates found so far | ||||
|      * @param cTau current search radius: the distance to the worst candidate kept so far | ||||
|      */ | ||||
|     public void search(Node node, INDArray target, int k, PriorityQueue<HeapObject> pq, double cTau) { | ||||
| 
 | ||||
|         if (node == null) | ||||
|             return; | ||||
| 
 | ||||
|         double tau = cTau; | ||||
| 
 | ||||
|         INDArray point = node.getPoint(); //items.getRow(node.getIndex()); | ||||
|         double distance = distance(point, target); | ||||
|         if (distance < tau) { | ||||
|             if (pq.size() == k) | ||||
|                 pq.poll(); | ||||
| 
 | ||||
|             pq.add(new HeapObject(node.getIndex(), node.getPoint(), distance)); | ||||
|             if (pq.size() == k) | ||||
|                 tau = pq.peek().getDistance(); | ||||
|         } | ||||
| 
 | ||||
|         Node left = node.getLeft(); | ||||
|         Node right = node.getRight(); | ||||
| 
 | ||||
|         if (left == null && right == null) | ||||
|             return; | ||||
| 
 | ||||
|         if (distance < node.getThreshold()) { | ||||
|             if (distance - tau < node.getThreshold()) { // if there can still be neighbors inside the ball, recursively search left child first | ||||
|                 search(left, target, k, pq, tau); | ||||
|             } | ||||
| 
 | ||||
|             if (distance + tau >= node.getThreshold()) { // if there can still be neighbors outside the ball, recursively search right child | ||||
|                 search(right, target, k, pq, tau); | ||||
|             } | ||||
| 
 | ||||
|         } else { | ||||
|             if (distance + tau >= node.getThreshold()) { // if there can still be neighbors outside the ball, recursively search right child first | ||||
|                 search(right, target, k, pq, tau); | ||||
|             } | ||||
| 
 | ||||
|             if (distance - tau < node.getThreshold()) { // if there can still be neighbors inside the ball, recursively search left child | ||||
|                 search(left, target, k, pq, tau); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     protected class HeapObjectComparator implements Comparator<HeapObject> { | ||||
| 
 | ||||
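|         // descending distance order turns the PriorityQueue into a max-heap: | ||||
|         // its head is the farthest current candidate and is evicted first | ||||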
|         @Override | ||||
|         public int compare(HeapObject o1, HeapObject o2) { | ||||
|             return Double.compare(o2.getDistance(), o1.getDistance()); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Data | ||||
|     public static class Node implements Serializable { | ||||
|         private static final long serialVersionUID = 2L; | ||||
| 
 | ||||
|         private int index; | ||||
|         private float threshold; | ||||
|         private Node left, right; | ||||
|         private INDArray point; | ||||
|         protected transient Future<Node> futureLeft; | ||||
|         protected transient Future<Node> futureRight; | ||||
| 
 | ||||
|         public Node(int index, float threshold) { | ||||
|             this.index = index; | ||||
|             this.threshold = threshold; | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
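|         // blocks until asynchronously built children are available, then | ||||
|         // recurses so every descendant future is resolved as well | ||||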
|         public void fetchFutures() { | ||||
|             try { | ||||
|                 if (futureLeft != null) { | ||||
|                     /*while (!futureLeft.isDone()) | ||||
|                         Thread.sleep(100);*/ | ||||
| 
 | ||||
| 
 | ||||
|                     left = futureLeft.get(); | ||||
|                 } | ||||
| 
 | ||||
|                 if (futureRight != null) { | ||||
|                     /*while (!futureRight.isDone()) | ||||
|                         Thread.sleep(100);*/ | ||||
| 
 | ||||
|                     right = futureRight.get(); | ||||
|                 } | ||||
| 
 | ||||
| 
 | ||||
|                 if (left != null) | ||||
|                     left.fetchFutures(); | ||||
| 
 | ||||
|                 if (right != null) | ||||
|                     right.fetchFutures(); | ||||
|             } catch (Exception e) { | ||||
|                 throw new RuntimeException(e); | ||||
|             } | ||||
| 
 | ||||
| 
 | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
| } | ||||
| @ -1,79 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.vptree; | ||||
| 
 | ||||
| import lombok.Getter; | ||||
| import org.deeplearning4j.clustering.sptree.DataPoint; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| public class VPTreeFillSearch { | ||||
|     private VPTree vpTree; | ||||
|     private int k; | ||||
|     @Getter | ||||
|     private List<DataPoint> results; | ||||
|     @Getter | ||||
|     private List<Double> distances; | ||||
|     private INDArray target; | ||||
| 
 | ||||
|     public VPTreeFillSearch(VPTree vpTree, int k, INDArray target) { | ||||
|         this.vpTree = vpTree; | ||||
|         this.k = k; | ||||
|         this.target = target; | ||||
|     } | ||||
| 
 | ||||
|     public void search() { | ||||
|         results = new ArrayList<>(); | ||||
|         distances = new ArrayList<>(); | ||||
|         //initial tree search is skipped: | ||||
|         //vpTree.search(target,k,results,distances); | ||||
| 
 | ||||
|         //instead, fill until there are k results by computing the distance | ||||
|         //from the target to every item and taking the k smallest | ||||
|         INDArray distancesArr = Nd4j.create(vpTree.getItems().rows(), 1); | ||||
|         vpTree.calcDistancesRelativeTo(target, distancesArr); | ||||
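|         //sort every item by its distance to the target; the loop below | ||||
|         //takes the first k entries of the index array as the neighbours | ||||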
|         INDArray[] sortWithIndices = Nd4j.sortWithIndices(distancesArr, 0, !vpTree.isInvert()); | ||||
|         results.clear(); | ||||
|         distances.clear(); | ||||
|         if (vpTree.getItems().isVector()) { | ||||
|             for (int i = 0; i < k; i++) { | ||||
|                 int idx = sortWithIndices[0].getInt(i); | ||||
|                 results.add(new DataPoint(idx, Nd4j.scalar(vpTree.getItems().getDouble(idx)))); | ||||
|                 distances.add(sortWithIndices[1].getDouble(i)); | ||||
|             } | ||||
|         } else { | ||||
|             for (int i = 0; i < k; i++) { | ||||
|                 int idx = sortWithIndices[0].getInt(i); | ||||
|                 results.add(new DataPoint(idx, vpTree.getItems().getRow(idx))); | ||||
|                 distances.add(sortWithIndices[1].getDouble(i)); | ||||
|             } | ||||
|         } | ||||
| 
 | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| } | ||||
| @ -1,21 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.vptree; | ||||
| @ -1,46 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.cluster; | ||||
| 
 | ||||
| import org.junit.Assert; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.List; | ||||
| 
 | ||||
| public class ClusterSetTest { | ||||
|     @Test | ||||
|     public void testGetMostPopulatedClusters() { | ||||
|         ClusterSet clusterSet = new ClusterSet(false); | ||||
|         List<Cluster> clusters = new ArrayList<>(); | ||||
|         for (int i = 0; i < 5; i++) { | ||||
|             Cluster cluster = new Cluster(); | ||||
|             cluster.setPoints(Point.toPoints(Nd4j.randn(i + 1, 5))); | ||||
|             clusters.add(cluster); | ||||
|         } | ||||
|         clusterSet.setClusters(clusters); | ||||
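|         // clusters hold 1..5 points; expect them back in descending size order | ||||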
|         List<Cluster> mostPopulatedClusters = clusterSet.getMostPopulatedClusters(5); | ||||
|         for (int i = 0; i < 5; i++) { | ||||
|             Assert.assertEquals(5 - i, mostPopulatedClusters.get(i).getPoints().size()); | ||||
|         } | ||||
|     } | ||||
| } | ||||
| @ -1,422 +0,0 @@ | ||||
| /* | ||||
|  *  ****************************************************************************** | ||||
|  *  * | ||||
|  *  * | ||||
|  *  * This program and the accompanying materials are made available under the | ||||
|  *  * terms of the Apache License, Version 2.0 which is available at | ||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. | ||||
|  *  * | ||||
|  *  *  See the NOTICE file distributed with this work for additional | ||||
|  *  *  information regarding copyright ownership. | ||||
|  *  * Unless required by applicable law or agreed to in writing, software | ||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||||
|  *  * License for the specific language governing permissions and limitations | ||||
|  *  * under the License. | ||||
|  *  * | ||||
|  *  * SPDX-License-Identifier: Apache-2.0 | ||||
|  *  ***************************************************************************** | ||||
|  */ | ||||
| 
 | ||||
| package org.deeplearning4j.clustering.kdtree; | ||||
| 
 | ||||
| import lombok.val; | ||||
| import org.deeplearning4j.BaseDL4JTest; | ||||
| import org.joda.time.Duration; | ||||
| import org.junit.Before; | ||||
| import org.junit.BeforeClass; | ||||
| import org.junit.Ignore; | ||||
| import org.junit.Test; | ||||
| import org.nd4j.linalg.api.buffer.DataType; | ||||
| import org.nd4j.linalg.api.ndarray.INDArray; | ||||
| import org.nd4j.linalg.factory.Nd4j; | ||||
| import org.nd4j.common.primitives.Pair; | ||||
| import org.nd4j.shade.guava.base.Stopwatch; | ||||
| import org.nd4j.shade.guava.primitives.Doubles; | ||||
| import org.nd4j.shade.guava.primitives.Floats; | ||||
| 
 | ||||
| import java.util.ArrayList; | ||||
| import java.util.Arrays; | ||||
| import java.util.List; | ||||
| import java.util.Random; | ||||
| 
 | ||||
| import static java.util.concurrent.TimeUnit.MILLISECONDS; | ||||
| import static java.util.concurrent.TimeUnit.SECONDS; | ||||
| import static org.junit.Assert.assertEquals; | ||||
| import static org.junit.Assert.assertTrue; | ||||
| 
 | ||||
| public class KDTreeTest extends BaseDL4JTest { | ||||
| 
 | ||||
|     @Override | ||||
|     public long getTimeoutMilliseconds() { | ||||
|         return 120000L; | ||||
|     } | ||||
| 
 | ||||
|     private KDTree kdTree; | ||||
| 
 | ||||
|     @BeforeClass | ||||
|     public static void beforeClass(){ | ||||
|         Nd4j.setDataType(DataType.FLOAT); | ||||
|     } | ||||
| 
 | ||||
|     @Before | ||||
|     public void setUp() { | ||||
|         kdTree = new KDTree(2); | ||||
|         float[] data = new float[]{7,2}; | ||||
|         kdTree.insert(Nd4j.createFromArray(data)); | ||||
|         data = new float[]{5,4}; | ||||
|         kdTree.insert(Nd4j.createFromArray(data)); | ||||
|         data = new float[]{2,3}; | ||||
|         kdTree.insert(Nd4j.createFromArray(data)); | ||||
|         data = new float[]{4,7}; | ||||
|         kdTree.insert(Nd4j.createFromArray(data)); | ||||
|         data = new float[]{9,6}; | ||||
|         kdTree.insert(Nd4j.createFromArray(data)); | ||||
|         data = new float[]{8,1}; | ||||
|         kdTree.insert(Nd4j.createFromArray(data)); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testTree() { | ||||
|         KDTree tree = new KDTree(2); | ||||
|         INDArray half = Nd4j.create(new double[] {0.5, 0.5}, new long[]{1,2}).castTo(DataType.FLOAT); | ||||
|         INDArray one = Nd4j.create(new double[] {1, 1}, new long[]{1,2}).castTo(DataType.FLOAT); | ||||
|         tree.insert(half); | ||||
|         tree.insert(one); | ||||
|         Pair<Double, INDArray> pair = tree.nn(Nd4j.create(new double[] {0.5, 0.5}, new long[]{1,2}).castTo(DataType.FLOAT)); | ||||
|         assertEquals(half, pair.getValue()); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testInsert() { | ||||
|         int elements = 10; | ||||
|         List<Double> digits = Arrays.asList(1.0, 0.0, 2.0, 3.0); | ||||
| 
 | ||||
|         KDTree kdTree = new KDTree(digits.size()); | ||||
|         List<List<Double>> lists = new ArrayList<>(); | ||||
|         for (int i = 0; i < elements; i++) { | ||||
|             List<Double> thisList = new ArrayList<>(digits.size()); | ||||
|             for (int k = 0; k < digits.size(); k++) { | ||||
|                 thisList.add(digits.get(k) + i); | ||||
|             } | ||||
|             lists.add(thisList); | ||||
|         } | ||||
| 
 | ||||
|         for (int i = 0; i < elements; i++) { | ||||
|             double[] features = Doubles.toArray(lists.get(i)); | ||||
|             INDArray ind = Nd4j.create(features, new long[]{1, features.length}, DataType.FLOAT); | ||||
|             kdTree.insert(ind); | ||||
|             assertEquals(i + 1, kdTree.size()); | ||||
|         } | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testDelete() { | ||||
|         int elements = 10; | ||||
|         List<Double> digits = Arrays.asList(1.0, 0.0, 2.0, 3.0); | ||||
| 
 | ||||
|         KDTree kdTree = new KDTree(digits.size()); | ||||
|         List<List<Double>> lists = new ArrayList<>(); | ||||
|         for (int i = 0; i < elements; i++) { | ||||
|             List<Double> thisList = new ArrayList<>(digits.size()); | ||||
|             for (int k = 0; k < digits.size(); k++) { | ||||
|                 thisList.add(digits.get(k) + i); | ||||
|             } | ||||
|             lists.add(thisList); | ||||
|         } | ||||
| 
 | ||||
|         INDArray toDelete = Nd4j.empty(DataType.DOUBLE), | ||||
|                  leafToDelete = Nd4j.empty(DataType.DOUBLE); | ||||
|         for (int i = 0; i < elements; i++) { | ||||
|             double[] features = Doubles.toArray(lists.get(i)); | ||||
|             INDArray ind = Nd4j.create(features, new long[]{1, features.length}, DataType.FLOAT); | ||||
|             if (i == 1) | ||||
|                 toDelete = ind; | ||||
|             if (i == elements - 1) { | ||||
|                 leafToDelete = ind; | ||||
|             } | ||||
|             kdTree.insert(ind); | ||||
|             assertEquals(i + 1, kdTree.size()); | ||||
|         } | ||||
| 
 | ||||
|         kdTree.delete(toDelete); | ||||
|         assertEquals(9, kdTree.size()); | ||||
|         kdTree.delete(leafToDelete); | ||||
|         assertEquals(8, kdTree.size()); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testNN() { | ||||
|         int n = 10; | ||||
| 
 | ||||
|         // make a KD-tree of dimension n | ||||
|         KDTree kdTree = new KDTree(n); | ||||
|         for (int i = -1; i < n; i++) { | ||||
|             // Insert a unit vector along each dimension | ||||
|             List<Double> vec = new ArrayList<>(n); | ||||
|             // i = -1 ensures the origin is in the tree | ||||
|             for (int k = 0; k < n; k++) { | ||||
|                 vec.add((k == i) ? 1.0 : 0.0); | ||||
|             } | ||||
|             INDArray indVec = Nd4j.create(Doubles.toArray(vec), new long[]{1, vec.size()}, DataType.FLOAT); | ||||
|             kdTree.insert(indVec); | ||||
|         } | ||||
|         Random rand = new Random(); | ||||
| 
 | ||||
|         // random point in the unit hypercube | ||||
|         List<Double> pt = new ArrayList<>(n); | ||||
|         for (int k = 0; k < n; k++) { | ||||
|             pt.add(rand.nextDouble()); | ||||
|         } | ||||
|         Pair<Double, INDArray> result = kdTree.nn(Nd4j.create(Doubles.toArray(pt), new long[]{1, pt.size()}, DataType.FLOAT)); | ||||
| 
 | ||||
|         // Always true for points in the unit hypercube | ||||
|         assertTrue(result.getKey() < Double.MAX_VALUE); | ||||
| 
 | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testKNN() { | ||||
|         int dimensions = 512; | ||||
|         int vectorsNo = isIntegrationTests() ? 50000 : 1000; | ||||
|         // make a KD-tree of the configured dimensionality | ||||
|         Stopwatch stopwatch = Stopwatch.createStarted(); | ||||
|         KDTree kdTree = new KDTree(dimensions); | ||||
|         for (int i = 0; i < vectorsNo; i++) { | ||||
|             // insert a random vector for each item | ||||
|             INDArray indVec = Nd4j.rand(DataType.FLOAT, 1,dimensions); | ||||
|             kdTree.insert(indVec); | ||||
|         } | ||||
|         stopwatch.stop(); | ||||
|         System.out.println("Time elapsed for " + kdTree.size() + " nodes construction is "+ stopwatch.elapsed(SECONDS)); | ||||
| 
 | ||||
|         Random rand = new Random(); | ||||
|         // random point, scaled to [0, 10) in each dimension | ||||
|         List<Double> pt = new ArrayList<>(dimensions); | ||||
|         for (int k = 0; k < dimensions; k++) { | ||||
|             pt.add(rand.nextFloat() * 10.0); | ||||
|         } | ||||
|         stopwatch.reset(); | ||||
|         stopwatch.start(); | ||||
|         List<Pair<Float, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Floats.toArray(pt))), 20.0f); | ||||
|         stopwatch.stop(); | ||||
|         System.out.println("Time elapsed for Search is "+ stopwatch.elapsed(MILLISECONDS)); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testKNN_Simple() { | ||||
|         int n = 2; | ||||
|         KDTree kdTree = new KDTree(n); | ||||
| 
 | ||||
|         float[] data = new float[]{3,3}; | ||||
|         kdTree.insert(Nd4j.createFromArray(data)); | ||||
|         data = new float[]{1,1}; | ||||
|         kdTree.insert(Nd4j.createFromArray(data)); | ||||
|         data = new float[]{2,2}; | ||||
|         kdTree.insert(Nd4j.createFromArray(data)); | ||||
| 
 | ||||
|         data = new float[]{0,0}; | ||||
|         List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 4.5f); | ||||
| 
 | ||||
|         assertEquals(1.0, result.get(0).getSecond().getDouble(0), 1e-5); | ||||
|         assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5); | ||||
| 
 | ||||
|         assertEquals(2.0, result.get(1).getSecond().getDouble(0), 1e-5); | ||||
|         assertEquals(2.0, result.get(1).getSecond().getDouble(1), 1e-5); | ||||
| 
 | ||||
|         assertEquals(3.0, result.get(2).getSecond().getDouble(0), 1e-5); | ||||
|         assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testKNN_1() { | ||||
| 
 | ||||
|         assertEquals(6, kdTree.size()); | ||||
| 
 | ||||
|         float[] data = new float[]{8,1}; | ||||
|         List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f); | ||||
|         assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5); | ||||
|         assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5); | ||||
|         assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5); | ||||
|         assertEquals(9.0, result.get(3).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(6.0, result.get(3).getSecond().getFloat(1), 1e-5); | ||||
|         assertEquals(2.0, result.get(4).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(3.0, result.get(4).getSecond().getFloat(1), 1e-5); | ||||
|         assertEquals(4.0, result.get(5).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(7.0, result.get(5).getSecond().getFloat(1), 1e-5); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testKNN_2() { | ||||
|         float[] data = new float[]{8, 1}; | ||||
|         List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f); | ||||
|         assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5); | ||||
|         assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5); | ||||
|         assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testKNN_3() { | ||||
| 
 | ||||
|         float[] data = new float[]{2, 3}; | ||||
|         List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f); | ||||
|         assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5); | ||||
|         assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5); | ||||
|         assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5); | ||||
|         assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5); | ||||
|         assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5); | ||||
|         assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5); | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
|     @Test | ||||
|     public void testKNN_4() { | ||||
|         float[] data = new float[]{2, 3}; | ||||
|         List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f); | ||||
|         assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5); | ||||
|         assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5); | ||||
|         assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testKNN_5() { | ||||
|         float[] data = new float[]{2, 3}; | ||||
|         List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f); | ||||
|         assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5); | ||||
|         assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5); | ||||
|         assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5); | ||||
|         assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5); | ||||
|         assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5); | ||||
|         assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5); | ||||
|         assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void test_KNN_6() { | ||||
|         float[] data = new float[]{4, 6}; | ||||
|         List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f); | ||||
|         assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); | ||||
|         assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); | ||||
|         assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); | ||||
|         assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); | ||||
|         assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5); | ||||
|         assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5); | ||||
|         assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5); | ||||
|         assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5); | ||||
|         assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5); | ||||
|         assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5); | ||||
|         assertEquals(8.0, result.get(5).getSecond().getDouble(0), 1e-5); | ||||
|         assertEquals(1.0, result.get(5).getSecond().getDouble(1), 1e-5); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void test_KNN_7() { | ||||
|         float[] data = new float[]{4, 6}; | ||||
|         List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f); | ||||
|         assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); | ||||
|         assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); | ||||
|         assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); | ||||
|         assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); | ||||
|         assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5); | ||||
|         assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5); | ||||
|         assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5); | ||||
|         assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5); | ||||
|         assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5); | ||||
|         assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void test_KNN_8() { | ||||
|         float[] data = new float[]{4, 6}; | ||||
|         List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f); | ||||
|         assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); | ||||
|         assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); | ||||
|         assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); | ||||
|         assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); | ||||
|         assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5); | ||||
|         assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5); | ||||
|         assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5); | ||||
|         assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5); | ||||
|         assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5); | ||||
|         assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5); | ||||
|         assertEquals(8.0, result.get(5).getSecond().getDouble(0), 1e-5); | ||||
|         assertEquals(1.0, result.get(5).getSecond().getDouble(1), 1e-5); | ||||
|     } | ||||
| 
 | ||||
|     @Test | ||||
|     public void testNoDuplicates() { | ||||
|         int N = 100; | ||||
|         KDTree bigTree = new KDTree(2); | ||||
| 
 | ||||
|         List<INDArray> points = new ArrayList<>(); | ||||
|         for (int i = 0; i < N; ++i) { | ||||
|             double[] data = new double[]{i, i}; | ||||
|             points.add(Nd4j.createFromArray(data)); | ||||
|         } | ||||
| 
 | ||||
|         for (int i = 0; i < N; ++i) { | ||||
|             bigTree.insert(points.get(i)); | ||||
|         } | ||||
| 
 | ||||
|         assertEquals(N, bigTree.size()); | ||||
| 
 | ||||
|         INDArray node = Nd4j.empty(DataType.DOUBLE); | ||||
|         for (int i = 0; i < N; ++i) { | ||||
|             node = bigTree.delete(node.isEmpty() ? points.get(i) : node); | ||||
|         } | ||||
| 
 | ||||
|         assertEquals(0, bigTree.size()); | ||||
|     } | ||||
| 
 | ||||
|     @Ignore | ||||
|     @Test | ||||
|     public void performanceTest() { | ||||
|         int n = 2; | ||||
|         int num = 100000; | ||||
|         // make a KD-tree of dimension n | ||||
|         long start = System.currentTimeMillis(); | ||||
|         KDTree kdTree = new KDTree(n); | ||||
|         INDArray inputArray = Nd4j.randn(DataType.DOUBLE, num, n); | ||||
|         for (int i = 0; i < num; ++i) { | ||||
|             kdTree.insert(inputArray.getRow(i)); | ||||
|         } | ||||
| 
 | ||||
|         long end = System.currentTimeMillis(); | ||||
|         Duration duration = new Duration(start, end); | ||||
|         System.out.println("Elapsed time for tree construction " + duration.getStandardSeconds() + " " + duration.getMillis()); | ||||
| 
 | ||||
|         List<Float> pt = new ArrayList<>(n); | ||||
|         for (int k = 0; k < n; k++) { | ||||
|             pt.add((float)(num / 2)); | ||||
|         } | ||||
|         start = System.currentTimeMillis(); | ||||
|         List<Pair<Float, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0f); | ||||
|         end = System.currentTimeMillis(); | ||||
|         duration = new Duration(start, end); | ||||
|         System.out.println("Elapsed time for tree search: " + duration.getStandardSeconds() + " s (" + duration.getMillis() + " ms)"); | ||||
|         for (val pair : list) { | ||||
|             System.out.println(pair.getFirst() + " " + pair.getSecond()); | ||||
|         } | ||||
|     } | ||||
| } | ||||