Remove more unused modules

parent fa8537f0c7
commit ee06fdd16f

@@ -1,64 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
  ~ /* ******************************************************************************
  ~  *
  ~  *
  ~  * This program and the accompanying materials are made available under the
  ~  * terms of the Apache License, Version 2.0 which is available at
  ~  * https://www.apache.org/licenses/LICENSE-2.0.
  ~  *
  ~  *  See the NOTICE file distributed with this work for additional
  ~  *  information regarding copyright ownership.
  ~  * Unless required by applicable law or agreed to in writing, software
  ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  ~  * License for the specific language governing permissions and limitations
  ~  * under the License.
  ~  *
  ~  * SPDX-License-Identifier: Apache-2.0
  ~  ******************************************************************************/
  -->

<project xmlns="http://maven.apache.org/POM/4.0.0"
    xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
    xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

    <modelVersion>4.0.0</modelVersion>

    <parent>
        <groupId>org.datavec</groupId>
        <artifactId>datavec-spark-inference-parent</artifactId>
        <version>1.0.0-SNAPSHOT</version>
    </parent>

    <artifactId>datavec-spark-inference-client</artifactId>

    <name>datavec-spark-inference-client</name>

    <dependencies>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-spark-inference-server_2.11</artifactId>
            <version>1.0.0-SNAPSHOT</version>
            <scope>test</scope>
        </dependency>
        <dependency>
            <groupId>org.datavec</groupId>
            <artifactId>datavec-spark-inference-model</artifactId>
            <version>${project.parent.version}</version>
        </dependency>
        <dependency>
            <groupId>com.mashape.unirest</groupId>
            <artifactId>unirest-java</artifactId>
        </dependency>
    </dependencies>

    <profiles>
        <profile>
            <id>test-nd4j-native</id>
        </profile>
        <profile>
            <id>test-nd4j-cuda-11.0</id>
        </profile>
    </profiles>
</project>
@@ -1,292 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.spark.inference.client;


import com.mashape.unirest.http.ObjectMapper;
import com.mashape.unirest.http.Unirest;
import com.mashape.unirest.http.exceptions.UnirestException;
import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.datavec.api.transform.TransformProcess;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.model.*;
import org.datavec.spark.inference.model.service.DataVecTransformService;
import org.nd4j.shade.jackson.core.JsonProcessingException;

import java.io.IOException;

@AllArgsConstructor
@Slf4j
public class DataVecTransformClient implements DataVecTransformService {
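    // Base URL of a running DataVec transform server, e.g. "http://localhost:9600"
    // (the example port is illustrative; any reachable server address works).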
    private String url;

    static {
        // Only one time
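        // Unirest exposes a single process-wide ObjectMapper, so it is configured
        // once in a static initializer, delegating JSON (de)serialization to the
        // nd4j-shaded Jackson mapper.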
        Unirest.setObjectMapper(new ObjectMapper() {
            private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
                    new org.nd4j.shade.jackson.databind.ObjectMapper();

            public <T> T readValue(String value, Class<T> valueType) {
                try {
                    return jacksonObjectMapper.readValue(value, valueType);
                } catch (IOException e) {
                    throw new RuntimeException(e);
                }
            }

            public String writeValue(Object value) {
                try {
                    return jacksonObjectMapper.writeValueAsString(value);
                } catch (JsonProcessingException e) {
                    throw new RuntimeException(e);
                }
            }
        });
    }

    /**
     * @param transformProcess
     */
    @Override
    public void setCSVTransformProcess(TransformProcess transformProcess) {
        try {
            String s = transformProcess.toJson();
            Unirest.post(url + "/transformprocess").header("accept", "application/json")
                    .header("Content-Type", "application/json").body(s).asJson();

        } catch (UnirestException e) {
            log.error("Error in setCSVTransformProcess()", e);
        }
    }

    @Override
    public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    /**
     * @return
     */
    @Override
    public TransformProcess getCSVTransformProcess() {
        try {
            String s = Unirest.get(url + "/transformprocess").header("accept", "application/json")
                    .header("Content-Type", "application/json").asString().getBody();
            return TransformProcess.fromJson(s);
        } catch (UnirestException e) {
            log.error("Error in getCSVTransformProcess()", e);
        }

        return null;
    }

    @Override
    public ImageTransformProcess getImageTransformProcess() {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    /**
     * @param transform
     * @return
     */
    @Override
    public SingleCSVRecord transformIncremental(SingleCSVRecord transform) {
        try {
            SingleCSVRecord singleCsvRecord = Unirest.post(url + "/transformincremental")
                    .header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .body(transform).asObject(SingleCSVRecord.class).getBody();
            return singleCsvRecord;
        } catch (UnirestException e) {
            log.error("Error in transformIncremental(SingleCSVRecord)", e);
        }
        return null;
    }


    /**
     * @param batchCSVRecord
     * @return
     */
    @Override
    public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
        try {
            SequenceBatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform").header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "TRUE")
                    .body(batchCSVRecord)
                    .asObject(SequenceBatchCSVRecord.class)
                    .getBody();
            return batchCSVRecord1;
        } catch (UnirestException e) {
|             log.error("",e); |  | ||||||
        }

        return null;
    }

    /**
     * @param batchCSVRecord
     * @return
     */
    @Override
    public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) {
        try {
            BatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform").header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "FALSE")
                    .body(batchCSVRecord)
                    .asObject(BatchCSVRecord.class)
                    .getBody();
            return batchCSVRecord1;
        } catch (UnirestException e) {
            log.error("Error in transform(BatchCSVRecord)", e);
        }

        return null;
    }

    /**
     * @param batchCSVRecord
     * @return
     */
    @Override
    public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) {
        try {
            Base64NDArrayBody batchArray1 = Unirest.post(url + "/transformarray").header("accept", "application/json")
                    .header("Content-Type", "application/json").body(batchCSVRecord)
                    .asObject(Base64NDArrayBody.class).getBody();
            return batchArray1;
        } catch (UnirestException e) {
            log.error("Error in transformArray(BatchCSVRecord)", e);
        }

        return null;
    }

    /**
     * @param singleCsvRecord
     * @return
     */
    @Override
    public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) {
        try {
            Base64NDArrayBody array = Unirest.post(url + "/transformincrementalarray")
                    .header("accept", "application/json").header("Content-Type", "application/json")
                    .body(singleCsvRecord).asObject(Base64NDArrayBody.class).getBody();
            return array;
        } catch (UnirestException e) {
            log.error("Error in transformArrayIncremental(SingleCSVRecord)", e);
        }

        return null;
    }

    @Override
    public Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    @Override
    public Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException {
        throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
    }

    /**
     * @param singleCsvRecord
     * @return
     */
    @Override
    public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) {
        try {
            Base64NDArrayBody array = Unirest.post(url + "/transformincrementalarray")
                    .header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "true")
                    .body(singleCsvRecord).asObject(Base64NDArrayBody.class).getBody();
            return array;
        } catch (UnirestException e) {
            log.error("Error in transformSequenceArrayIncremental", e);
        }

        return null;
    }

    /**
     * @param batchCSVRecord
     * @return
     */
    @Override
    public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) {
        try {
            Base64NDArrayBody batchArray1 = Unirest.post(url + "/transformarray").header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "true")
                    .body(batchCSVRecord)
                    .asObject(Base64NDArrayBody.class).getBody();
            return batchArray1;
        } catch (UnirestException e) {
            log.error("Error in transformSequenceArray", e);
        }

        return null;
    }

    /**
     * @param batchCSVRecord
     * @return
     */
    @Override
    public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) {
        try {
            SequenceBatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform")
                    .header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "true")
                    .body(batchCSVRecord)
                    .asObject(SequenceBatchCSVRecord.class).getBody();
            return batchCSVRecord1;
        } catch (UnirestException e) {
|             log.error("Error in transformSequence"); |  | ||||||
        }

        return null;
    }

    /**
     * @param transform
     * @return
     */
    @Override
    public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) {
        try {
            SequenceBatchCSVRecord singleCsvRecord = Unirest.post(url + "/transformincremental")
                    .header("accept", "application/json")
                    .header("Content-Type", "application/json")
                    .header(SEQUENCE_OR_NOT_HEADER, "true")
                    .body(transform).asObject(SequenceBatchCSVRecord.class).getBody();
            return singleCsvRecord;
        } catch (UnirestException e) {
|             log.error("Error in transformSequenceIncremental"); |  | ||||||
        }
        return null;
    }
}
@@ -1,45 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */
package org.datavec.transform.client;

import lombok.extern.slf4j.Slf4j;
import org.nd4j.common.tests.AbstractAssertTestsClass;
import org.nd4j.common.tests.BaseND4JTest;
import java.util.*;

@Slf4j
public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {

    @Override
    protected Set<Class<?>> getExclusions() {
        // Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
        return new HashSet<>();
    }

    @Override
    protected String getPackageName() {
        return "org.datavec.transform.client";
    }

    @Override
    protected Class<?> getBaseClass() {
        return BaseND4JTest.class;
    }
}
@@ -1,139 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.transform.client;

import org.apache.commons.io.FileUtils;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;
import org.datavec.spark.inference.server.CSVSparkTransformServer;
import org.datavec.spark.inference.client.DataVecTransformClient;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;

import java.io.File;
import java.io.IOException;
import java.net.ServerSocket;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.UUID;

import static org.junit.Assert.assertEquals;
import static org.junit.Assume.assumeNotNull;

public class DataVecTransformClientTest {
    private static CSVSparkTransformServer server;
    private static int port = getAvailablePort();
    private static DataVecTransformClient client;
    private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
    private static TransformProcess transformProcess =
            new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build();
    private static File fileSave = new File(UUID.randomUUID().toString() + ".json");

    @BeforeClass
    public static void beforeClass() throws Exception {
        FileUtils.write(fileSave, transformProcess.toJson());
        fileSave.deleteOnExit();
        server = new CSVSparkTransformServer();
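        // Start the server via its CLI entry point; "-dp" supplies the HTTP port
        // (the free port chosen above).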
|         server.runMain(new String[] {"-dp", String.valueOf(port)}); |  | ||||||
| 
 |  | ||||||
|         client = new DataVecTransformClient("http://localhost:" + port); |  | ||||||
|         client.setCSVTransformProcess(transformProcess); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @AfterClass |  | ||||||
|     public static void afterClass() throws Exception { |  | ||||||
|         server.stop(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testSequenceClient() { |  | ||||||
|         SequenceBatchCSVRecord sequenceBatchCSVRecord = new SequenceBatchCSVRecord(); |  | ||||||
|         SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"}); |  | ||||||
| 
 |  | ||||||
|         BatchCSVRecord batchCSVRecord = new BatchCSVRecord(Arrays.asList(singleCsvRecord, singleCsvRecord)); |  | ||||||
|         List<BatchCSVRecord> batchCSVRecordList = new ArrayList<>(); |  | ||||||
|         for(int i = 0; i < 5; i++) { |  | ||||||
|              batchCSVRecordList.add(batchCSVRecord); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         sequenceBatchCSVRecord.add(batchCSVRecordList); |  | ||||||
| 
 |  | ||||||
|         SequenceBatchCSVRecord sequenceBatchCSVRecord1 = client.transformSequence(sequenceBatchCSVRecord); |  | ||||||
|         assumeNotNull(sequenceBatchCSVRecord1); |  | ||||||
| 
 |  | ||||||
|         Base64NDArrayBody array = client.transformSequenceArray(sequenceBatchCSVRecord); |  | ||||||
|         assumeNotNull(array); |  | ||||||
| 
 |  | ||||||
|         Base64NDArrayBody incrementalBody = client.transformSequenceArrayIncremental(batchCSVRecord); |  | ||||||
|         assumeNotNull(incrementalBody); |  | ||||||
| 
 |  | ||||||
|         Base64NDArrayBody incrementalSequenceBody = client.transformSequenceArrayIncremental(batchCSVRecord); |  | ||||||
|         assumeNotNull(incrementalSequenceBody); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testRecord() throws Exception { |  | ||||||
|         SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"}); |  | ||||||
|         SingleCSVRecord transformed = client.transformIncremental(singleCsvRecord); |  | ||||||
|         assertEquals(singleCsvRecord.getValues().size(), transformed.getValues().size()); |  | ||||||
|         Base64NDArrayBody body = client.transformArrayIncremental(singleCsvRecord); |  | ||||||
|         INDArray arr = Nd4jBase64.fromBase64(body.getNdarray()); |  | ||||||
|         assumeNotNull(arr); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testBatchRecord() throws Exception { |  | ||||||
|         SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"}); |  | ||||||
| 
 |  | ||||||
|         BatchCSVRecord batchCSVRecord = new BatchCSVRecord(Arrays.asList(singleCsvRecord, singleCsvRecord)); |  | ||||||
|         BatchCSVRecord batchCSVRecord1 = client.transform(batchCSVRecord); |  | ||||||
|         assertEquals(batchCSVRecord.getRecords().size(), batchCSVRecord1.getRecords().size()); |  | ||||||
| 
 |  | ||||||
|         Base64NDArrayBody body = client.transformArray(batchCSVRecord); |  | ||||||
|         INDArray arr = Nd4jBase64.fromBase64(body.getNdarray()); |  | ||||||
|         assumeNotNull(arr); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
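    /**
     * Bind a ServerSocket to port 0 so the OS assigns a free ephemeral port, then
     * close the socket and reuse that port for the test server. Note the small
     * race window between releasing the port and the server binding it.
     */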
    public static int getAvailablePort() {
        try {
            ServerSocket socket = new ServerSocket(0);
            try {
                return socket.getLocalPort();
            } finally {
                socket.close();
            }
        } catch (IOException e) {
            throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e);
        }
    }

}
@@ -1,6 +0,0 @@
play.modules.enabled += com.lightbend.lagom.discovery.zookeeper.ZooKeeperServiceLocatorModule
play.modules.enabled += io.skymind.skil.service.PredictionModule
play.crypto.secret = as8dufasdfuasdfjkasdkfalksjfk
play.server.pidfile.path=/tmp/RUNNING_PID

play.server.http.port = 9600
| @ -1,63 +0,0 @@ | |||||||
| <?xml version="1.0" encoding="UTF-8"?> |  | ||||||
| <!-- |  | ||||||
|   ~ /* ****************************************************************************** |  | ||||||
|   ~  * |  | ||||||
|   ~  * |  | ||||||
|   ~  * This program and the accompanying materials are made available under the |  | ||||||
|   ~  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|   ~  * |  | ||||||
|   ~  *  See the NOTICE file distributed with this work for additional |  | ||||||
|   ~  *  information regarding copyright ownership. |  | ||||||
|   ~  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|   ~  * License for the specific language governing permissions and limitations |  | ||||||
|   ~  * under the License. |  | ||||||
|   ~  * |  | ||||||
|   ~  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|   ~  ******************************************************************************/ |  | ||||||
|   --> |  | ||||||
| 
 |  | ||||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" |  | ||||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" |  | ||||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> |  | ||||||
| 
 |  | ||||||
|     <modelVersion>4.0.0</modelVersion> |  | ||||||
| 
 |  | ||||||
|     <parent> |  | ||||||
|         <groupId>org.datavec</groupId> |  | ||||||
|         <artifactId>datavec-spark-inference-parent</artifactId> |  | ||||||
|         <version>1.0.0-SNAPSHOT</version> |  | ||||||
|     </parent> |  | ||||||
| 
 |  | ||||||
|     <artifactId>datavec-spark-inference-model</artifactId> |  | ||||||
| 
 |  | ||||||
|     <name>datavec-spark-inference-model</name> |  | ||||||
| 
 |  | ||||||
|     <dependencies> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.datavec</groupId> |  | ||||||
|             <artifactId>datavec-api</artifactId> |  | ||||||
|             <version>${datavec.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.datavec</groupId> |  | ||||||
|             <artifactId>datavec-data-image</artifactId> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.datavec</groupId> |  | ||||||
|             <artifactId>datavec-local</artifactId> |  | ||||||
|             <version>${project.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|     </dependencies> |  | ||||||
| 
 |  | ||||||
|     <profiles> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-native</id> |  | ||||||
|         </profile> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-cuda-11.0</id> |  | ||||||
|         </profile> |  | ||||||
|     </profiles> |  | ||||||
| </project> |  | ||||||
@@ -1,286 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.spark.inference.model;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.util.ndarray.RecordConverter;
import org.datavec.api.writable.Writable;
import org.datavec.arrow.ArrowConverter;
import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch;
import org.datavec.local.transforms.LocalTransformExecutor;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchCSVRecord;
import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord;
import org.datavec.spark.inference.model.model.SingleCSVRecord;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.serde.base64.Nd4jBase64;

import java.io.IOException;
import java.util.Arrays;
import java.util.List;

import static org.datavec.arrow.ArrowConverter.*;
import static org.datavec.local.transforms.LocalTransformExecutor.execute;
import static org.datavec.local.transforms.LocalTransformExecutor.executeToSequence;

@AllArgsConstructor
@Slf4j
public class CSVSparkTransform {
    @Getter
    private TransformProcess transformProcess;
    private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
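    // Shared Arrow allocator for all column conversions; Long.MAX_VALUE means this
    // allocator enforces no memory limit of its own.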

    /**
     * Convert a raw record via the {@link TransformProcess}
     * to a base64-encoded ndarray
     * @param batch the record to convert
     * @return the base64-encoded ndarray
     * @throws IOException
     */
    public Base64NDArrayBody toArray(BatchCSVRecord batch) throws IOException {
        List<List<Writable>> converted = execute(toArrowWritables(toArrowColumnsString(
                bufferAllocator, transformProcess.getInitialSchema(),
                batch.getRecordsAsString()),
                transformProcess.getInitialSchema()), transformProcess);

        ArrowWritableRecordBatch arrowRecordBatch = (ArrowWritableRecordBatch) converted;
        INDArray convert = ArrowConverter.toArray(arrowRecordBatch);
        return new Base64NDArrayBody(Nd4jBase64.base64String(convert));
    }

    /**
     * Convert a raw record via the {@link TransformProcess}
     * to a base64-encoded ndarray
     * @param record the record to convert
     * @return the base64-encoded ndarray
     * @throws IOException
     */
    public Base64NDArrayBody toArray(SingleCSVRecord record) throws IOException {
        List<Writable> record2 = toArrowWritablesSingle(
                toArrowColumnsStringSingle(bufferAllocator,
                        transformProcess.getInitialSchema(), record.getValues()),
                transformProcess.getInitialSchema());
        List<Writable> finalRecord = execute(Arrays.asList(record2), transformProcess).get(0);
        INDArray convert = RecordConverter.toArray(DataType.DOUBLE, finalRecord);
        return new Base64NDArrayBody(Nd4jBase64.base64String(convert));
    }

    /**
     * Runs the transform process
     * @param batch the record to transform
     * @return the transformed record
     */
    public BatchCSVRecord transform(BatchCSVRecord batch) {
        BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
        List<List<Writable>> converted = execute(toArrowWritables(toArrowColumnsString(
                bufferAllocator, transformProcess.getInitialSchema(),
                batch.getRecordsAsString()),
                transformProcess.getInitialSchema()), transformProcess);
        int numCols = converted.get(0).size();
        for (int row = 0; row < converted.size(); row++) {
            String[] values = new String[numCols];
            for (int i = 0; i < values.length; i++)
                values[i] = converted.get(row).get(i).toString();
            batchCSVRecord.add(new SingleCSVRecord(values));
        }

        return batchCSVRecord;
    }

    /**
     * Runs the transform process
     * @param record the record to transform
     * @return the transformed record
     */
    public SingleCSVRecord transform(SingleCSVRecord record) {
        List<Writable> record2 = toArrowWritablesSingle(
                toArrowColumnsStringSingle(bufferAllocator,
                        transformProcess.getInitialSchema(), record.getValues()),
                transformProcess.getInitialSchema());
        List<Writable> finalRecord = execute(Arrays.asList(record2), transformProcess).get(0);
        String[] values = new String[finalRecord.size()];
        for (int i = 0; i < values.length; i++)
            values[i] = finalRecord.get(i).toString();
        return new SingleCSVRecord(values);
    }

    /**
     * @param transform
     * @return
     */
    public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) {
        // Sequence schema?
        List<List<List<Writable>>> converted = executeToSequence(
                toArrowWritables(toArrowColumnsStringTimeSeries(
                        bufferAllocator, transformProcess.getInitialSchema(),
                        Arrays.asList(transform.getRecordsAsString())),
                        transformProcess.getInitialSchema()), transformProcess);

        SequenceBatchCSVRecord batchCSVRecord = new SequenceBatchCSVRecord();
        for (int i = 0; i < converted.size(); i++) {
            BatchCSVRecord batchCSVRecord1 = BatchCSVRecord.fromWritables(converted.get(i));
            batchCSVRecord.add(Arrays.asList(batchCSVRecord1));
        }

        return batchCSVRecord;
    }

    /**
     * @param batchCSVRecordSequence
     * @return
     */
    public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecordSequence) {
        List<List<List<String>>> recordsAsString = batchCSVRecordSequence.getRecordsAsString();
        boolean allSameLength = true;
        Integer length = null;
        for (List<List<String>> record : recordsAsString) {
            if (length == null) {
                length = record.size();
            } else if (record.size() != length) {
                allSameLength = false;
            }
        }
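        // Arrow's time-series batch requires every sequence in the batch to have
        // the same number of steps; ragged batches fall back to the slower
        // string-conversion path below.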

        if (allSameLength) {
            List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), recordsAsString);
            ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors,
                    transformProcess.getInitialSchema(),
                    recordsAsString.get(0).get(0).size());
            val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch, transformProcess);
            return SequenceBatchCSVRecord.fromWritables(transformed);
        } else {
            val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecordSequence.getRecordsAsString(), transformProcess.getInitialSchema()), transformProcess);
            return SequenceBatchCSVRecord.fromWritables(transformed);
        }
    }

    /**
     * TODO: optimize
     * @param batchCSVRecordSequence
     * @return
     */
    public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecordSequence) {
        List<List<List<String>>> strings = batchCSVRecordSequence.getRecordsAsString();
        boolean allSameLength = true;
        Integer length = null;
        for (List<List<String>> record : strings) {
            if (length == null) {
                length = record.size();
            } else if (record.size() != length) {
                allSameLength = false;
            }
        }

        if (allSameLength) {
            List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), strings);
            ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors, transformProcess.getInitialSchema(), strings.get(0).get(0).size());
            val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch, transformProcess);
            INDArray arr = RecordConverter.toTensor(transformed).reshape(strings.size(), strings.get(0).get(0).size(), strings.get(0).size());
            try {
                return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
            } catch (IOException e) {
                throw new IllegalStateException(e);
            }
        } else {
            val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecordSequence.getRecordsAsString(), transformProcess.getInitialSchema()), transformProcess);
            INDArray arr = RecordConverter.toTensor(transformed).reshape(strings.size(), strings.get(0).get(0).size(), strings.get(0).size());
            try {
                return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
            } catch (IOException e) {
                throw new IllegalStateException(e);
            }
        }
    }

    /**
     * @param singleCsvRecord
     * @return
     */
    public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) {
        List<List<List<Writable>>> converted = executeToSequence(toArrowWritables(toArrowColumnsString(
                bufferAllocator, transformProcess.getInitialSchema(),
                singleCsvRecord.getRecordsAsString()),
                transformProcess.getInitialSchema()), transformProcess);
        ArrowWritableRecordTimeSeriesBatch arrowWritableRecordBatch = (ArrowWritableRecordTimeSeriesBatch) converted;
        INDArray arr = RecordConverter.toTensor(arrowWritableRecordBatch);
        try {
            return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
        } catch (IOException e) {
|             log.error("",e); |  | ||||||
        }

        return null;
    }

    public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
        List<List<List<String>>> strings = batchCSVRecord.getRecordsAsString();
        boolean allSameLength = true;
        Integer length = null;
        for (List<List<String>> record : strings) {
            if (length == null) {
                length = record.size();
            } else if (record.size() != length) {
                allSameLength = false;
            }
        }

        if (allSameLength) {
            List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), strings);
            ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors, transformProcess.getInitialSchema(), strings.get(0).get(0).size());
            val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch, transformProcess);
            return SequenceBatchCSVRecord.fromWritables(transformed);
        } else {
            val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecord.getRecordsAsString(), transformProcess.getInitialSchema()), transformProcess);
            return SequenceBatchCSVRecord.fromWritables(transformed);
        }
    }
}
@@ -1,64 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.spark.inference.model;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.datavec.image.data.ImageWritable;
import org.datavec.image.transform.ImageTransformProcess;
import org.datavec.spark.inference.model.model.Base64NDArrayBody;
import org.datavec.spark.inference.model.model.BatchImageRecord;
import org.datavec.spark.inference.model.model.SingleImageRecord;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.serde.base64.Nd4jBase64;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

@AllArgsConstructor
public class ImageSparkTransform {
    @Getter
    private ImageTransformProcess imageTransformProcess;

    public Base64NDArrayBody toArray(SingleImageRecord record) throws IOException {
        ImageWritable record2 = imageTransformProcess.transformFileUriToInput(record.getUri());
        INDArray finalRecord = imageTransformProcess.executeArray(record2);

        return new Base64NDArrayBody(Nd4jBase64.base64String(finalRecord));
    }

    public Base64NDArrayBody toArray(BatchImageRecord batch) throws IOException {
        List<INDArray> records = new ArrayList<>();

        for (SingleImageRecord imgRecord : batch.getRecords()) {
            ImageWritable record2 = imageTransformProcess.transformFileUriToInput(imgRecord.getUri());
            INDArray finalRecord = imageTransformProcess.executeArray(record2);
            records.add(finalRecord);
        }

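        // Stack the per-image arrays along dimension 0 to form one mini-batch.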
        INDArray array = Nd4j.concat(0, records.toArray(new INDArray[records.size()]));

        return new Base64NDArrayBody(Nd4jBase64.base64String(array));
    }

}
@@ -1,32 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.datavec.spark.inference.model.model;

import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;

@Data
@AllArgsConstructor
@NoArgsConstructor
public class Base64NDArrayBody {
    private String ndarray;
}
| @ -1,104 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.inference.model.model; |  | ||||||
| 
 |  | ||||||
| import lombok.AllArgsConstructor; |  | ||||||
| import lombok.Builder; |  | ||||||
| import lombok.Data; |  | ||||||
| import lombok.NoArgsConstructor; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.nd4j.linalg.dataset.DataSet; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| @AllArgsConstructor |  | ||||||
| @Builder |  | ||||||
| @NoArgsConstructor |  | ||||||
| public class BatchCSVRecord implements Serializable { |  | ||||||
|     private List<SingleCSVRecord> records; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get the records as a list of strings |  | ||||||
|      * (basically the underlying values for |  | ||||||
|      * {@link SingleCSVRecord}) |  | ||||||
|      * @return |  | ||||||
|      */ |  | ||||||
|     public List<List<String>> getRecordsAsString() { |  | ||||||
|         if(records == null) |  | ||||||
|             records = new ArrayList<>(); |  | ||||||
|         List<List<String>> ret = new ArrayList<>(); |  | ||||||
|         for(SingleCSVRecord csvRecord : records) { |  | ||||||
|             ret.add(csvRecord.getValues()); |  | ||||||
|         } |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Create a batch csv record |  | ||||||
|      * from a list of writables. |  | ||||||
|      * @param batch |  | ||||||
|      * @return |  | ||||||
|      */ |  | ||||||
|     public static BatchCSVRecord fromWritables(List<List<Writable>> batch) { |  | ||||||
|         List <SingleCSVRecord> records = new ArrayList<>(batch.size()); |  | ||||||
|         for(List<Writable> list : batch) { |  | ||||||
|             List<String> add = new ArrayList<>(list.size()); |  | ||||||
|             for(Writable writable : list) { |  | ||||||
|                 add.add(writable.toString()); |  | ||||||
|             } |  | ||||||
|             records.add(new SingleCSVRecord(add)); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return BatchCSVRecord.builder().records(records).build(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Add a record |  | ||||||
|      * @param record |  | ||||||
|      */ |  | ||||||
|     public void add(SingleCSVRecord record) { |  | ||||||
|         if (records == null) |  | ||||||
|             records = new ArrayList<>(); |  | ||||||
|         records.add(record); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Return a batch record based on a dataset |  | ||||||
|      * @param dataSet the dataset to get the batch record for |  | ||||||
|      * @return the batch record |  | ||||||
|      */ |  | ||||||
|     public static BatchCSVRecord fromDataSet(DataSet dataSet) { |  | ||||||
|         BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); |  | ||||||
|         for (int i = 0; i < dataSet.numExamples(); i++) { |  | ||||||
|             batchCSVRecord.add(SingleCSVRecord.fromRow(dataSet.get(i))); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return batchCSVRecord; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
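
For context, a minimal usage sketch of the BatchCSVRecord class removed above (assumes the datavec-spark-inference-model artifact is still on the classpath; the column values are illustrative only):

    import org.datavec.api.writable.DoubleWritable;
    import org.datavec.api.writable.Writable;
    import org.datavec.spark.inference.model.model.BatchCSVRecord;

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.List;

    public class BatchCSVRecordSketch {
        public static void main(String[] args) {
            //two rows of two double columns each
            List<List<Writable>> batch = new ArrayList<>();
            batch.add(Arrays.<Writable>asList(new DoubleWritable(1.0), new DoubleWritable(2.0)));
            batch.add(Arrays.<Writable>asList(new DoubleWritable(3.0), new DoubleWritable(4.0)));

            //each inner list becomes one SingleCSVRecord holding string values
            BatchCSVRecord record = BatchCSVRecord.fromWritables(batch);
            System.out.println(record.getRecordsAsString()); //[[1.0, 2.0], [3.0, 4.0]]
        }
    }
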
| @ -1,50 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.inference.model.model; |  | ||||||
| 
 |  | ||||||
| import lombok.AllArgsConstructor; |  | ||||||
| import lombok.Data; |  | ||||||
| import lombok.NoArgsConstructor; |  | ||||||
| 
 |  | ||||||
| import java.net.URI; |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| @AllArgsConstructor |  | ||||||
| @NoArgsConstructor |  | ||||||
| public class BatchImageRecord { |  | ||||||
|     private List<SingleImageRecord> records; |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Add a record |  | ||||||
|      * @param record the record to add |  | ||||||
|      */ |  | ||||||
|     public void add(SingleImageRecord record) { |  | ||||||
|         if (records == null) |  | ||||||
|             records = new ArrayList<>(); |  | ||||||
|         records.add(record); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void add(URI uri) { |  | ||||||
|         this.add(new SingleImageRecord(uri)); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
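
Similarly, a brief sketch of accumulating images into the BatchImageRecord removed above (the file URIs are hypothetical):

    import org.datavec.spark.inference.model.model.BatchImageRecord;

    import java.net.URI;

    public class BatchImageRecordSketch {
        public static void main(String[] args) {
            BatchImageRecord batch = new BatchImageRecord();
            //add(URI) wraps each URI in a SingleImageRecord; the backing list is created lazily
            batch.add(URI.create("file:///tmp/images/A.jpg")); //hypothetical path
            batch.add(URI.create("file:///tmp/images/B.png")); //hypothetical path
            System.out.println(batch.getRecords().size());     //2
        }
    }
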
| @ -1,106 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.inference.model.model; |  | ||||||
| 
 |  | ||||||
| import lombok.AllArgsConstructor; |  | ||||||
| import lombok.Builder; |  | ||||||
| import lombok.Data; |  | ||||||
| import lombok.NoArgsConstructor; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.nd4j.linalg.dataset.DataSet; |  | ||||||
| import org.nd4j.linalg.dataset.MultiDataSet; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.Arrays; |  | ||||||
| import java.util.Collections; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| @AllArgsConstructor |  | ||||||
| @Builder |  | ||||||
| @NoArgsConstructor |  | ||||||
| public class SequenceBatchCSVRecord implements Serializable { |  | ||||||
|     private List<List<BatchCSVRecord>> records; |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Add a record |  | ||||||
|      * @param record the batch of records to append as one sequence |  | ||||||
|      */ |  | ||||||
|     public void add(List<BatchCSVRecord> record) { |  | ||||||
|         if (records == null) |  | ||||||
|             records = new ArrayList<>(); |  | ||||||
|         records.add(record); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get the records as a list of strings directly |  | ||||||
|      * (this basically "unpacks" the objects) |  | ||||||
|      * @return the unpacked values, one list of string lists per sequence |  | ||||||
|      */ |  | ||||||
|     public List<List<List<String>>> getRecordsAsString() { |  | ||||||
|         if(records == null) |  | ||||||
|             return Collections.emptyList(); |  | ||||||
|         List<List<List<String>>> ret = new ArrayList<>(records.size()); |  | ||||||
|         for(List<BatchCSVRecord> record : records) { |  | ||||||
|             List<List<String>> add = new ArrayList<>(); |  | ||||||
|             for(BatchCSVRecord batchCSVRecord : record) { |  | ||||||
|                 for (SingleCSVRecord singleCSVRecord : batchCSVRecord.getRecords()) { |  | ||||||
|                     add.add(singleCSVRecord.getValues()); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             ret.add(add); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Convert a writables time series to a sequence batch |  | ||||||
|      * @param input the time series to convert, one list of steps per sequence |  | ||||||
|      * @return the resulting sequence batch csv record |  | ||||||
|      */ |  | ||||||
|     public static SequenceBatchCSVRecord fromWritables(List<List<List<Writable>>> input) { |  | ||||||
|         SequenceBatchCSVRecord ret = new SequenceBatchCSVRecord(); |  | ||||||
|         for(int i = 0; i < input.size(); i++) { |  | ||||||
|             ret.add(Arrays.asList(BatchCSVRecord.fromWritables(input.get(i)))); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Return a sequence batch record based on a multi dataset |  | ||||||
|      * @param dataSet the multi dataset to get the batch record for |  | ||||||
|      * @return the sequence batch record |  | ||||||
|      */ |  | ||||||
|     public static SequenceBatchCSVRecord fromDataSet(MultiDataSet dataSet) { |  | ||||||
|         SequenceBatchCSVRecord batchCSVRecord = new SequenceBatchCSVRecord(); |  | ||||||
|         for (int i = 0; i < dataSet.numFeatureArrays(); i++) { |  | ||||||
|             batchCSVRecord.add(Arrays.asList(BatchCSVRecord.fromDataSet(new DataSet(dataSet.getFeatures(i),dataSet.getLabels(i))))); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return batchCSVRecord; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
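
A minimal sketch of building the SequenceBatchCSVRecord removed above from a writables time series (shapes and values are illustrative): each element of the outer list is one sequence, holding one list of writables per time step.

    import org.datavec.api.writable.Text;
    import org.datavec.api.writable.Writable;
    import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord;

    import java.util.Arrays;
    import java.util.List;

    public class SequenceBatchCSVRecordSketch {
        public static void main(String[] args) {
            //one sequence with two time steps and a single string column per step
            List<List<List<Writable>>> timeSeries = Arrays.asList(
                    Arrays.asList(
                            Arrays.<Writable>asList(new Text("step1")),
                            Arrays.<Writable>asList(new Text("step2"))));

            SequenceBatchCSVRecord sequence = SequenceBatchCSVRecord.fromWritables(timeSeries);
            System.out.println(sequence.getRecordsAsString()); //[[[step1], [step2]]]
        }
    }
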
| @ -1,95 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.inference.model.model; |  | ||||||
| 
 |  | ||||||
| import lombok.AllArgsConstructor; |  | ||||||
| import lombok.Data; |  | ||||||
| import lombok.NoArgsConstructor; |  | ||||||
| import org.nd4j.linalg.dataset.DataSet; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| import java.util.Arrays; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| @AllArgsConstructor |  | ||||||
| @NoArgsConstructor |  | ||||||
| public class SingleCSVRecord implements Serializable { |  | ||||||
|     private List<String> values; |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Create from an array of values (uses a list internally) |  | ||||||
|      * @param values the values to wrap |  | ||||||
|      */ |  | ||||||
|     public SingleCSVRecord(String...values) { |  | ||||||
|         this.values = Arrays.asList(values); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Instantiate a csv record from a vector. |  | ||||||
|      * Given an input dataset with a one hot label |  | ||||||
|      * matrix, the index of the maximum label is |  | ||||||
|      * appended to the end of the record; for |  | ||||||
|      * regression, all label values are appended instead |  | ||||||
|      * @param row the input vectors |  | ||||||
|      * @return the record from this {@link DataSet} |  | ||||||
|      */ |  | ||||||
|     public static SingleCSVRecord fromRow(DataSet row) { |  | ||||||
|         if (!row.getFeatures().isVector() && !row.getFeatures().isScalar()) |  | ||||||
|             throw new IllegalArgumentException("Passed in dataset must represent a scalar or vector"); |  | ||||||
|         if (!row.getLabels().isVector() && !row.getLabels().isScalar()) |  | ||||||
|             throw new IllegalArgumentException("Passed in dataset labels must be a scalar or vector"); |  | ||||||
|         //classification: labels summing to exactly 1.0 are treated as one hot |  | ||||||
|         SingleCSVRecord record; |  | ||||||
|         int idx = 0; |  | ||||||
|         if (row.getLabels().sumNumber().doubleValue() == 1.0) { |  | ||||||
|             String[] values = new String[row.getFeatures().columns() + 1]; |  | ||||||
|             for (int i = 0; i < row.getFeatures().length(); i++) { |  | ||||||
|                 values[idx++] = String.valueOf(row.getFeatures().getDouble(i)); |  | ||||||
|             } |  | ||||||
|             int maxIdx = 0; |  | ||||||
|             for (int i = 0; i < row.getLabels().length(); i++) { |  | ||||||
|                 if (row.getLabels().getDouble(maxIdx) < row.getLabels().getDouble(i)) { |  | ||||||
|                     maxIdx = i; |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             values[idx++] = String.valueOf(maxIdx); |  | ||||||
|             record = new SingleCSVRecord(values); |  | ||||||
|         } |  | ||||||
|         //regression (any number of values) |  | ||||||
|         else { |  | ||||||
|             String[] values = new String[row.getFeatures().columns() + row.getLabels().columns()]; |  | ||||||
|             for (int i = 0; i < row.getFeatures().length(); i++) { |  | ||||||
|                 values[idx++] = String.valueOf(row.getFeatures().getDouble(i)); |  | ||||||
|             } |  | ||||||
|             for (int i = 0; i < row.getLabels().length(); i++) { |  | ||||||
|                 values[idx++] = String.valueOf(row.getLabels().getDouble(i)); |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|             record = new SingleCSVRecord(values); |  | ||||||
| 
 |  | ||||||
|         } |  | ||||||
|         return record; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
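
The classification/regression split in fromRow above is worth spelling out: labels that sum to exactly 1.0 are treated as one hot, so only the argmax index is appended; any other label sum falls through to the regression branch, which appends every label value. A sketch (shapes and values are illustrative):

    import org.datavec.spark.inference.model.model.SingleCSVRecord;
    import org.nd4j.linalg.dataset.DataSet;
    import org.nd4j.linalg.factory.Nd4j;

    public class SingleCSVRecordSketch {
        public static void main(String[] args) {
            //one hot labels (sum == 1.0): the argmax index is appended -> 3 values
            DataSet classifierRow = new DataSet(
                    Nd4j.create(new double[] {0.5, 0.5}, new long[] {1, 2}),
                    Nd4j.create(new double[] {0.0, 1.0}, new long[] {1, 2}));
            System.out.println(SingleCSVRecord.fromRow(classifierRow).getValues()); //[0.5, 0.5, 1]

            //any other label sum takes the regression branch: all labels appended -> 4 values
            DataSet regressionRow = new DataSet(
                    Nd4j.create(new double[] {0.5, 0.5}, new long[] {1, 2}),
                    Nd4j.create(new double[] {1.0, 1.0}, new long[] {1, 2}));
            System.out.println(SingleCSVRecord.fromRow(regressionRow).getValues()); //[0.5, 0.5, 1.0, 1.0]
        }
    }
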
| @ -1,34 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.inference.model.model; |  | ||||||
| 
 |  | ||||||
| import lombok.AllArgsConstructor; |  | ||||||
| import lombok.Data; |  | ||||||
| import lombok.NoArgsConstructor; |  | ||||||
| 
 |  | ||||||
| import java.net.URI; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| @AllArgsConstructor |  | ||||||
| @NoArgsConstructor |  | ||||||
| public class SingleImageRecord { |  | ||||||
|     private URI uri; |  | ||||||
| } |  | ||||||
| @ -1,131 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.inference.model.service; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.transform.TransformProcess; |  | ||||||
| import org.datavec.image.transform.ImageTransformProcess; |  | ||||||
| import org.datavec.spark.inference.model.model.*; |  | ||||||
| 
 |  | ||||||
| import java.io.IOException; |  | ||||||
| 
 |  | ||||||
| public interface DataVecTransformService { |  | ||||||
| 
 |  | ||||||
|     String SEQUENCE_OR_NOT_HEADER = "Sequence"; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Set the transform process used for csv records |  | ||||||
|      * @param transformProcess the transform process to use |  | ||||||
|      */ |  | ||||||
|     void setCSVTransformProcess(TransformProcess transformProcess); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Set the transform process used for images |  | ||||||
|      * @param imageTransformProcess the image transform process to use |  | ||||||
|      */ |  | ||||||
|     void setImageTransformProcess(ImageTransformProcess imageTransformProcess); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get the transform process used for csv records |  | ||||||
|      * @return the current csv transform process |  | ||||||
|      */ |  | ||||||
|     TransformProcess getCSVTransformProcess(); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get the transform process used for images |  | ||||||
|      * @return the current image transform process |  | ||||||
|      */ |  | ||||||
|     ImageTransformProcess getImageTransformProcess(); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Transform a single csv record |  | ||||||
|      * @param singleCsvRecord the record to transform |  | ||||||
|      * @return the transformed record |  | ||||||
|      */ |  | ||||||
|     SingleCSVRecord transformIncremental(SingleCSVRecord singleCsvRecord); |  | ||||||
| 
 |  | ||||||
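|     /** |  | ||||||
|      * Transform a sequence batch of csv records |  | ||||||
|      * @param batchCSVRecord the sequence batch to transform |  | ||||||
|      * @return the transformed sequence batch |  | ||||||
|      */ |  | ||||||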
|     SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Transform a batch of csv records |  | ||||||
|      * @param batchCSVRecord the batch to transform |  | ||||||
|      * @return the transformed batch |  | ||||||
|      */ |  | ||||||
|     BatchCSVRecord transform(BatchCSVRecord batchCSVRecord); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Convert a batch of csv records to a base 64 encoded ndarray |  | ||||||
|      * @param batchCSVRecord the batch to convert |  | ||||||
|      * @return the base 64 encoded ndarray |  | ||||||
|      */ |  | ||||||
|     Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Convert a single csv record to a base 64 encoded ndarray |  | ||||||
|      * @param singleCsvRecord the record to convert |  | ||||||
|      * @return the base 64 encoded ndarray |  | ||||||
|      */ |  | ||||||
|     Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Convert a single image record to a base 64 encoded ndarray |  | ||||||
|      * @param singleImageRecord the image to convert |  | ||||||
|      * @return the base 64 encoded ndarray |  | ||||||
|      * @throws IOException if the image cannot be read |  | ||||||
|      */ |  | ||||||
|     Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException; |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Convert a batch of image records to a base 64 encoded ndarray |  | ||||||
|      * @param batchImageRecord the batch of images to convert |  | ||||||
|      * @return the base 64 encoded ndarray |  | ||||||
|      * @throws IOException if an image cannot be read |  | ||||||
|      */ |  | ||||||
|     Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException; |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Convert a batch of csv records to a base 64 encoded sequence ndarray |  | ||||||
|      * @param singleCsvRecord the batch to convert |  | ||||||
|      * @return the base 64 encoded sequence ndarray |  | ||||||
|      */ |  | ||||||
|     Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Convert a sequence batch of csv records to a base 64 encoded ndarray |  | ||||||
|      * @param batchCSVRecord the sequence batch to convert |  | ||||||
|      * @return the base 64 encoded ndarray |  | ||||||
|      */ |  | ||||||
|     Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Transform a sequence batch of csv records in place |  | ||||||
|      * @param batchCSVRecord the sequence batch to transform |  | ||||||
|      * @return the transformed sequence batch |  | ||||||
|      */ |  | ||||||
|     SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Transform a batch of csv records into a sequence batch |  | ||||||
|      * @param transform the batch to transform |  | ||||||
|      * @return the resulting sequence batch |  | ||||||
|      */ |  | ||||||
|     SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform); |  | ||||||
| } |  | ||||||
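
For context, CSVSparkTransform (exercised by the tests below) supplies the CSV half of this interface; a minimal round trip through a transform process might look like the following sketch (the schema and values are illustrative):

    import org.datavec.api.transform.TransformProcess;
    import org.datavec.api.transform.schema.Schema;
    import org.datavec.spark.inference.model.CSVSparkTransform;
    import org.datavec.spark.inference.model.model.SingleCSVRecord;

    public class TransformServiceSketch {
        public static void main(String[] args) throws Exception {
            //a one column schema and a trivial convert-to-string process
            Schema schema = new Schema.Builder().addColumnDouble("col").build();
            TransformProcess transformProcess =
                    new TransformProcess.Builder(schema).convertToString("col").build();

            //CSVSparkTransform runs the process directly, as in the tests below
            CSVSparkTransform transform = new CSVSparkTransform(transformProcess);
            SingleCSVRecord transformed = transform.transform(new SingleCSVRecord("1.0"));
            System.out.println(transformed.getValues()); //[1.0]
        }
    }
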
| @ -1,46 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| package org.datavec.spark.transform; |  | ||||||
| 
 |  | ||||||
| import lombok.extern.slf4j.Slf4j; |  | ||||||
| import org.nd4j.common.tests.AbstractAssertTestsClass; |  | ||||||
| import org.nd4j.common.tests.BaseND4JTest; |  | ||||||
| 
 |  | ||||||
| import java.util.*; |  | ||||||
| 
 |  | ||||||
| @Slf4j |  | ||||||
| public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     protected Set<Class<?>> getExclusions() { |  | ||||||
|         //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) |  | ||||||
|         return new HashSet<>(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     protected String getPackageName() { |  | ||||||
|         return "org.datavec.spark.transform"; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     protected Class<?> getBaseClass() { |  | ||||||
|         return BaseND4JTest.class; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,40 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.transform; |  | ||||||
| 
 |  | ||||||
| import org.datavec.spark.inference.model.model.BatchCSVRecord; |  | ||||||
| import org.junit.Test; |  | ||||||
| import org.nd4j.linalg.dataset.DataSet; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| 
 |  | ||||||
| import static org.junit.Assert.assertEquals; |  | ||||||
| 
 |  | ||||||
| public class BatchCSVRecordTest { |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testBatchRecordCreationFromDataSet() { |  | ||||||
|         DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{1, 1}, {1, 1}})); |  | ||||||
| 
 |  | ||||||
|         BatchCSVRecord batchCSVRecord = BatchCSVRecord.fromDataSet(dataSet); |  | ||||||
|         assertEquals(2, batchCSVRecord.getRecords().size()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,212 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.transform; |  | ||||||
| 
 |  | ||||||
| import org.datavec.api.transform.TransformProcess; |  | ||||||
| import org.datavec.api.transform.schema.Schema; |  | ||||||
| import org.datavec.api.transform.transform.integer.BaseIntegerTransform; |  | ||||||
| import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform; |  | ||||||
| import org.datavec.api.writable.DoubleWritable; |  | ||||||
| import org.datavec.api.writable.Text; |  | ||||||
| import org.datavec.api.writable.Writable; |  | ||||||
| import org.datavec.spark.inference.model.CSVSparkTransform; |  | ||||||
| import org.datavec.spark.inference.model.model.Base64NDArrayBody; |  | ||||||
| import org.datavec.spark.inference.model.model.BatchCSVRecord; |  | ||||||
| import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord; |  | ||||||
| import org.datavec.spark.inference.model.model.SingleCSVRecord; |  | ||||||
| import org.junit.Test; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.serde.base64.Nd4jBase64; |  | ||||||
| 
 |  | ||||||
| import java.util.*; |  | ||||||
| 
 |  | ||||||
| import static org.junit.Assert.*; |  | ||||||
| 
 |  | ||||||
| public class CSVSparkTransformTest { |  | ||||||
|     @Test |  | ||||||
|     public void testTransformer() throws Exception { |  | ||||||
|         List<Writable> input = new ArrayList<>(); |  | ||||||
|         input.add(new DoubleWritable(1.0)); |  | ||||||
|         input.add(new DoubleWritable(2.0)); |  | ||||||
| 
 |  | ||||||
|         Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); |  | ||||||
|         List<Writable> output = new ArrayList<>(); |  | ||||||
|         output.add(new Text("1.0")); |  | ||||||
|         output.add(new Text("2.0")); |  | ||||||
| 
 |  | ||||||
|         TransformProcess transformProcess = |  | ||||||
|                 new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build(); |  | ||||||
|         CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess); |  | ||||||
|         String[] values = new String[] {"1.0", "2.0"}; |  | ||||||
|         SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values)); |  | ||||||
|         Base64NDArrayBody body = csvSparkTransform.toArray(new SingleCSVRecord(values)); |  | ||||||
|         INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); |  | ||||||
|         assertTrue(fromBase64.isVector()); |  | ||||||
| //        System.out.println("Base 64ed array " + fromBase64); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testTransformerBatch() throws Exception { |  | ||||||
|         List<Writable> input = new ArrayList<>(); |  | ||||||
|         input.add(new DoubleWritable(1.0)); |  | ||||||
|         input.add(new DoubleWritable(2.0)); |  | ||||||
| 
 |  | ||||||
|         Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); |  | ||||||
|         List<Writable> output = new ArrayList<>(); |  | ||||||
|         output.add(new Text("1.0")); |  | ||||||
|         output.add(new Text("2.0")); |  | ||||||
| 
 |  | ||||||
|         TransformProcess transformProcess = |  | ||||||
|                 new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build(); |  | ||||||
|         CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess); |  | ||||||
|         String[] values = new String[] {"1.0", "2.0"}; |  | ||||||
|         SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values)); |  | ||||||
|         BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); |  | ||||||
|         for (int i = 0; i < 3; i++) |  | ||||||
|             batchCSVRecord.add(record); |  | ||||||
|         //data type is string, unable to convert |  | ||||||
|         BatchCSVRecord batchCSVRecord1 = csvSparkTransform.transform(batchCSVRecord); |  | ||||||
|       /*  Base64NDArrayBody body = csvSparkTransform.toArray(batchCSVRecord1); |  | ||||||
|         INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); |  | ||||||
|         assertTrue(fromBase64.isMatrix()); |  | ||||||
|         System.out.println("Base 64ed array " + fromBase64); */ |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testSingleBatchSequence() throws Exception { |  | ||||||
|         List<Writable> input = new ArrayList<>(); |  | ||||||
|         input.add(new DoubleWritable(1.0)); |  | ||||||
|         input.add(new DoubleWritable(2.0)); |  | ||||||
| 
 |  | ||||||
|         Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); |  | ||||||
|         List<Writable> output = new ArrayList<>(); |  | ||||||
|         output.add(new Text("1.0")); |  | ||||||
|         output.add(new Text("2.0")); |  | ||||||
| 
 |  | ||||||
|         TransformProcess transformProcess = |  | ||||||
|                 new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build(); |  | ||||||
|         CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess); |  | ||||||
|         String[] values = new String[] {"1.0", "2.0"}; |  | ||||||
|         SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values)); |  | ||||||
|         BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); |  | ||||||
|         for (int i = 0; i < 3; i++) |  | ||||||
|             batchCSVRecord.add(record); |  | ||||||
|         BatchCSVRecord batchCSVRecord1 = csvSparkTransform.transform(batchCSVRecord); |  | ||||||
|         SequenceBatchCSVRecord sequenceBatchCSVRecord = new SequenceBatchCSVRecord(); |  | ||||||
|         sequenceBatchCSVRecord.add(Arrays.asList(batchCSVRecord)); |  | ||||||
|         Base64NDArrayBody sequenceArray = csvSparkTransform.transformSequenceArray(sequenceBatchCSVRecord); |  | ||||||
|         INDArray outputBody = Nd4jBase64.fromBase64(sequenceArray.getNdarray()); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|          //ensure accumulation |  | ||||||
|         sequenceBatchCSVRecord.add(Arrays.asList(batchCSVRecord)); |  | ||||||
|         sequenceArray = csvSparkTransform.transformSequenceArray(sequenceBatchCSVRecord); |  | ||||||
|         assertArrayEquals(new long[]{2,2,3},Nd4jBase64.fromBase64(sequenceArray.getNdarray()).shape()); |  | ||||||
| 
 |  | ||||||
|         SequenceBatchCSVRecord transformed = csvSparkTransform.transformSequence(sequenceBatchCSVRecord); |  | ||||||
|         assertNotNull(transformed.getRecords()); |  | ||||||
| //        System.out.println(transformed); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testSpecificSequence() throws Exception { |  | ||||||
|         final Schema schema = new Schema.Builder() |  | ||||||
|                 .addColumnsString("action") |  | ||||||
|                 .build(); |  | ||||||
| 
 |  | ||||||
|         final TransformProcess transformProcess = new TransformProcess.Builder(schema) |  | ||||||
|                 .removeAllColumnsExceptFor("action") |  | ||||||
|                 .transform(new ConverToLowercase("action")) |  | ||||||
|                 .convertToSequence() |  | ||||||
|                 .transform(new TextToCharacterIndexTransform("action", "action_sequence", |  | ||||||
|                         defaultCharIndex(), false)) |  | ||||||
|                 .integerToOneHot("action_sequence",0,29) |  | ||||||
|                 .build(); |  | ||||||
| 
 |  | ||||||
|         final String[] data1 = new String[] { "test1" }; |  | ||||||
|         final String[] data2 = new String[] { "test2" }; |  | ||||||
|         final BatchCSVRecord batchCsvRecord = new BatchCSVRecord( |  | ||||||
|                 Arrays.asList( |  | ||||||
|                         new SingleCSVRecord(data1), |  | ||||||
|                         new SingleCSVRecord(data2))); |  | ||||||
| 
 |  | ||||||
|         final CSVSparkTransform transform = new CSVSparkTransform(transformProcess); |  | ||||||
| //        System.out.println(transform.transformSequenceIncremental(batchCsvRecord)); |  | ||||||
|         transform.transformSequenceIncremental(batchCsvRecord); |  | ||||||
|         assertEquals(3,Nd4jBase64.fromBase64(transform.transformSequenceArrayIncremental(batchCsvRecord).getNdarray()).rank()); |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private static Map<Character,Integer> defaultCharIndex() { |  | ||||||
|         Map<Character,Integer> ret = new TreeMap<>(); |  | ||||||
| 
 |  | ||||||
|         ret.put('a',0); |  | ||||||
|         ret.put('b',1); |  | ||||||
|         ret.put('c',2); |  | ||||||
|         ret.put('d',3); |  | ||||||
|         ret.put('e',4); |  | ||||||
|         ret.put('f',5); |  | ||||||
|         ret.put('g',6); |  | ||||||
|         ret.put('h',7); |  | ||||||
|         ret.put('i',8); |  | ||||||
|         ret.put('j',9); |  | ||||||
|         ret.put('k',10); |  | ||||||
|         ret.put('l',11); |  | ||||||
|         ret.put('m',12); |  | ||||||
|         ret.put('n',13); |  | ||||||
|         ret.put('o',14); |  | ||||||
|         ret.put('p',15); |  | ||||||
|         ret.put('q',16); |  | ||||||
|         ret.put('r',17); |  | ||||||
|         ret.put('s',18); |  | ||||||
|         ret.put('t',19); |  | ||||||
|         ret.put('u',20); |  | ||||||
|         ret.put('v',21); |  | ||||||
|         ret.put('w',22); |  | ||||||
|         ret.put('x',23); |  | ||||||
|         ret.put('y',24); |  | ||||||
|         ret.put('z',25); |  | ||||||
|         ret.put('/',26); |  | ||||||
|         ret.put(' ',27); |  | ||||||
|         ret.put('(',28); |  | ||||||
|         ret.put(')',29); |  | ||||||
| 
 |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public static class ConverToLowercase extends BaseIntegerTransform { |  | ||||||
|         public ConverToLowercase(String column) { |  | ||||||
|             super(column); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         public Text map(Writable writable) { |  | ||||||
|             return new Text(writable.toString().toLowerCase()); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         public Object map(Object input) { |  | ||||||
|             return new Text(input.toString().toLowerCase()); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,86 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.transform; |  | ||||||
| 
 |  | ||||||
| import org.datavec.image.transform.ImageTransformProcess; |  | ||||||
| import org.datavec.spark.inference.model.ImageSparkTransform; |  | ||||||
| import org.datavec.spark.inference.model.model.Base64NDArrayBody; |  | ||||||
| import org.datavec.spark.inference.model.model.BatchImageRecord; |  | ||||||
| import org.datavec.spark.inference.model.model.SingleImageRecord; |  | ||||||
| import org.junit.Rule; |  | ||||||
| import org.junit.Test; |  | ||||||
| import org.junit.rules.TemporaryFolder; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.common.io.ClassPathResource; |  | ||||||
| import org.nd4j.serde.base64.Nd4jBase64; |  | ||||||
| 
 |  | ||||||
| import java.io.File; |  | ||||||
| 
 |  | ||||||
| import static org.junit.Assert.assertEquals; |  | ||||||
| 
 |  | ||||||
| public class ImageSparkTransformTest { |  | ||||||
| 
 |  | ||||||
|     @Rule |  | ||||||
|     public TemporaryFolder testDir = new TemporaryFolder(); |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testSingleImageSparkTransform() throws Exception { |  | ||||||
|         int seed = 12345; |  | ||||||
| 
 |  | ||||||
|         File f1 = new ClassPathResource("datavec-spark-inference/testimages/class1/A.jpg").getFile(); |  | ||||||
| 
 |  | ||||||
|         SingleImageRecord imgRecord = new SingleImageRecord(f1.toURI()); |  | ||||||
| 
 |  | ||||||
|         ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(seed) |  | ||||||
|                         .scaleImageTransform(10).cropImageTransform(5).build(); |  | ||||||
| 
 |  | ||||||
|         ImageSparkTransform imgSparkTransform = new ImageSparkTransform(imgTransformProcess); |  | ||||||
|         Base64NDArrayBody body = imgSparkTransform.toArray(imgRecord); |  | ||||||
| 
 |  | ||||||
|         INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); |  | ||||||
| //        System.out.println("Base 64ed array " + fromBase64); |  | ||||||
|         assertEquals(1, fromBase64.size(0)); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testBatchImageSparkTransform() throws Exception { |  | ||||||
|         int seed = 12345; |  | ||||||
| 
 |  | ||||||
|         File f0 = new ClassPathResource("datavec-spark-inference/testimages/class1/A.jpg").getFile(); |  | ||||||
|         File f1 = new ClassPathResource("datavec-spark-inference/testimages/class1/B.png").getFile(); |  | ||||||
|         File f2 = new ClassPathResource("datavec-spark-inference/testimages/class1/C.jpg").getFile(); |  | ||||||
| 
 |  | ||||||
|         BatchImageRecord batch = new BatchImageRecord(); |  | ||||||
|         batch.add(f0.toURI()); |  | ||||||
|         batch.add(f1.toURI()); |  | ||||||
|         batch.add(f2.toURI()); |  | ||||||
| 
 |  | ||||||
|         ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(seed) |  | ||||||
|                         .scaleImageTransform(10).cropImageTransform(5).build(); |  | ||||||
| 
 |  | ||||||
|         ImageSparkTransform imgSparkTransform = new ImageSparkTransform(imgTransformProcess); |  | ||||||
|         Base64NDArrayBody body = imgSparkTransform.toArray(batch); |  | ||||||
| 
 |  | ||||||
|         INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray()); |  | ||||||
| //        System.out.println("Base 64ed array " + fromBase64); |  | ||||||
|         assertEquals(3, fromBase64.size(0)); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,60 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.transform; |  | ||||||
| 
 |  | ||||||
| import org.datavec.spark.inference.model.model.SingleCSVRecord; |  | ||||||
| import org.junit.Test; |  | ||||||
| import org.nd4j.linalg.dataset.DataSet; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| 
 |  | ||||||
| import static org.junit.Assert.assertEquals; |  | ||||||
| import static org.junit.Assert.fail; |  | ||||||
| 
 |  | ||||||
| public class SingleCSVRecordTest { |  | ||||||
| 
 |  | ||||||
|     @Test(expected = IllegalArgumentException.class) |  | ||||||
|     public void testVectorAssertion() { |  | ||||||
|         DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(1, 1)); |  | ||||||
|         SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet); |  | ||||||
|         fail(singleCsvRecord.toString() + " should have thrown an exception"); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testVectorOneHotLabel() { |  | ||||||
|         DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{0, 1}, {1, 0}})); |  | ||||||
| 
 |  | ||||||
|         //assert |  | ||||||
|         SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet.get(0)); |  | ||||||
|         assertEquals(3, singleCsvRecord.getValues().size()); |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testVectorRegression() { |  | ||||||
|         DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{1, 1}, {1, 1}})); |  | ||||||
| 
 |  | ||||||
|         //assert |  | ||||||
|         SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet.get(0)); |  | ||||||
|         assertEquals(4, singleCsvRecord.getValues().size()); |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,47 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.transform; |  | ||||||
| 
 |  | ||||||
| import org.datavec.spark.inference.model.model.SingleImageRecord; |  | ||||||
| import org.junit.Rule; |  | ||||||
| import org.junit.Test; |  | ||||||
| import org.junit.rules.TemporaryFolder; |  | ||||||
| import org.nd4j.common.io.ClassPathResource; |  | ||||||
| 
 |  | ||||||
| import java.io.File; |  | ||||||
| 
 |  | ||||||
| public class SingleImageRecordTest { |  | ||||||
| 
 |  | ||||||
|     @Rule |  | ||||||
|     public TemporaryFolder testDir = new TemporaryFolder(); |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testImageRecord() throws Exception { |  | ||||||
|         File f = testDir.newFolder(); |  | ||||||
|         new ClassPathResource("datavec-spark-inference/testimages/").copyDirectory(f); |  | ||||||
|         File f0 = new File(f, "class0/0.jpg"); |  | ||||||
|         File f1 = new File(f, "/class1/A.jpg"); |  | ||||||
| 
 |  | ||||||
|         SingleImageRecord imgRecord = new SingleImageRecord(f0.toURI()); |  | ||||||
| 
 |  | ||||||
|         // TODO: add a Jackson serialization round trip test? |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,154 +0,0 @@ | |||||||
| <?xml version="1.0" encoding="UTF-8"?> |  | ||||||
| <!-- |  | ||||||
|   ~ /* ****************************************************************************** |  | ||||||
|   ~  * |  | ||||||
|   ~  * |  | ||||||
|   ~  * This program and the accompanying materials are made available under the |  | ||||||
|   ~  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|   ~  * |  | ||||||
|   ~  *  See the NOTICE file distributed with this work for additional |  | ||||||
|   ~  *  information regarding copyright ownership. |  | ||||||
|   ~  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|   ~  * License for the specific language governing permissions and limitations |  | ||||||
|   ~  * under the License. |  | ||||||
|   ~  * |  | ||||||
|   ~  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|   ~  ******************************************************************************/ |  | ||||||
|   --> |  | ||||||
| 
 |  | ||||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" |  | ||||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" |  | ||||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> |  | ||||||
| 
 |  | ||||||
|     <modelVersion>4.0.0</modelVersion> |  | ||||||
| 
 |  | ||||||
|     <parent> |  | ||||||
|         <groupId>org.datavec</groupId> |  | ||||||
|         <artifactId>datavec-spark-inference-parent</artifactId> |  | ||||||
|         <version>1.0.0-SNAPSHOT</version> |  | ||||||
|     </parent> |  | ||||||
| 
 |  | ||||||
|     <artifactId>datavec-spark-inference-server_2.11</artifactId> |  | ||||||
| 
 |  | ||||||
|     <name>datavec-spark-inference-server</name> |  | ||||||
| 
 |  | ||||||
|     <properties> |  | ||||||
|         <!-- Default scala versions, may be overwritten by build profiles --> |  | ||||||
|         <scala.version>2.11.12</scala.version> |  | ||||||
|         <scala.binary.version>2.11</scala.binary.version> |  | ||||||
|         <maven.compiler.source>1.8</maven.compiler.source> |  | ||||||
|         <maven.compiler.target>1.8</maven.compiler.target> |  | ||||||
|     </properties> |  | ||||||
| 
 |  | ||||||
|     <dependencies> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.datavec</groupId> |  | ||||||
|             <artifactId>datavec-spark-inference-model</artifactId> |  | ||||||
|             <version>${datavec.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.datavec</groupId> |  | ||||||
|             <artifactId>datavec-spark_2.11</artifactId> |  | ||||||
|             <version>${project.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.datavec</groupId> |  | ||||||
|             <artifactId>datavec-data-image</artifactId> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>joda-time</groupId> |  | ||||||
|             <artifactId>joda-time</artifactId> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.apache.commons</groupId> |  | ||||||
|             <artifactId>commons-lang3</artifactId> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.hibernate</groupId> |  | ||||||
|             <artifactId>hibernate-validator</artifactId> |  | ||||||
|             <version>${hibernate.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.scala-lang</groupId> |  | ||||||
|             <artifactId>scala-library</artifactId> |  | ||||||
|             <version>${scala.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.scala-lang</groupId> |  | ||||||
|             <artifactId>scala-reflect</artifactId> |  | ||||||
|             <version>${scala.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>com.typesafe.play</groupId> |  | ||||||
|             <artifactId>play-java_2.11</artifactId> |  | ||||||
|             <version>${playframework.version}</version> |  | ||||||
|             <exclusions> |  | ||||||
|                 <exclusion> |  | ||||||
|                     <groupId>com.google.code.findbugs</groupId> |  | ||||||
|                     <artifactId>jsr305</artifactId> |  | ||||||
|                 </exclusion> |  | ||||||
|                 <exclusion> |  | ||||||
|                     <groupId>net.jodah</groupId> |  | ||||||
|                     <artifactId>typetools</artifactId> |  | ||||||
|                 </exclusion> |  | ||||||
|             </exclusions> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>net.jodah</groupId> |  | ||||||
|             <artifactId>typetools</artifactId> |  | ||||||
|             <version>${jodah.typetools.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>com.typesafe.play</groupId> |  | ||||||
|             <artifactId>play-json_2.11</artifactId> |  | ||||||
|             <version>${playframework.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>com.typesafe.play</groupId> |  | ||||||
|             <artifactId>play-server_2.11</artifactId> |  | ||||||
|             <version>${playframework.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>com.typesafe.play</groupId> |  | ||||||
|             <artifactId>play_2.11</artifactId> |  | ||||||
|             <version>${playframework.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>com.typesafe.play</groupId> |  | ||||||
|             <artifactId>play-netty-server_2.11</artifactId> |  | ||||||
|             <version>${playframework.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>com.typesafe.akka</groupId> |  | ||||||
|             <artifactId>akka-cluster_2.11</artifactId> |  | ||||||
|             <version>2.5.23</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>com.mashape.unirest</groupId> |  | ||||||
|             <artifactId>unirest-java</artifactId> |  | ||||||
|             <scope>test</scope> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>com.beust</groupId> |  | ||||||
|             <artifactId>jcommander</artifactId> |  | ||||||
|             <version>${jcommander.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.apache.spark</groupId> |  | ||||||
|             <artifactId>spark-core_2.11</artifactId> |  | ||||||
|             <version>${spark.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|     </dependencies> |  | ||||||
| 
 |  | ||||||
|     <profiles> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-native</id> |  | ||||||
|         </profile> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-cuda-11.0</id> |  | ||||||
|         </profile> |  | ||||||
|     </profiles> |  | ||||||
| </project> |  | ||||||
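
Given unirest-java is declared above as a test dependency, a hedged sketch of pushing a transform process to the server's /transformprocess endpoint defined below (the host, port, and JSON payload are placeholders; a real payload would come from TransformProcess.toJson()):

    import com.mashape.unirest.http.HttpResponse;
    import com.mashape.unirest.http.Unirest;

    public class TransformProcessClientSketch {
        public static void main(String[] args) throws Exception {
            String transformProcessJson = "{}"; //placeholder: output of TransformProcess.toJson()
            //hypothetical host and port; the server binds whatever port it was launched with
            HttpResponse<String> response = Unirest.post("http://localhost:9000/transformprocess")
                    .header("Content-Type", "application/json")
                    .body(transformProcessJson)
                    .asString();
            System.out.println(response.getStatus()); //200 on success
        }
    }
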
| @ -1,352 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.inference.server; |  | ||||||
| 
 |  | ||||||
| import com.beust.jcommander.JCommander; |  | ||||||
| import com.beust.jcommander.ParameterException; |  | ||||||
| import lombok.Data; |  | ||||||
| import lombok.extern.slf4j.Slf4j; |  | ||||||
| import org.apache.commons.io.FileUtils; |  | ||||||
| import org.datavec.api.transform.TransformProcess; |  | ||||||
| import org.datavec.image.transform.ImageTransformProcess; |  | ||||||
| import org.datavec.spark.inference.model.CSVSparkTransform; |  | ||||||
| import org.datavec.spark.inference.model.model.*; |  | ||||||
| import play.BuiltInComponents; |  | ||||||
| import play.Mode; |  | ||||||
| import play.routing.Router; |  | ||||||
| import play.routing.RoutingDsl; |  | ||||||
| import play.server.Server; |  | ||||||
| 
 |  | ||||||
| import java.io.File; |  | ||||||
| import java.io.IOException; |  | ||||||
| import java.util.Base64; |  | ||||||
| import java.util.Random; |  | ||||||
| 
 |  | ||||||
| import static play.mvc.Results.*; |  | ||||||
| 
 |  | ||||||
| @Slf4j |  | ||||||
| @Data |  | ||||||
| public class CSVSparkTransformServer extends SparkTransformServer { |  | ||||||
|     private CSVSparkTransform transform; |  | ||||||
| 
 |  | ||||||
|     public void runMain(String[] args) throws Exception { |  | ||||||
|         JCommander jcmdr = new JCommander(this); |  | ||||||
| 
 |  | ||||||
|         try { |  | ||||||
|             jcmdr.parse(args); |  | ||||||
|         } catch (ParameterException e) { |  | ||||||
|             //User provides invalid input -> print the usage info |  | ||||||
|             jcmdr.usage(); |  | ||||||
|             if (jsonPath == null) |  | ||||||
|                 System.err.println("Json path parameter is missing."); |  | ||||||
|             try { |  | ||||||
|                 Thread.sleep(500); |  | ||||||
|             } catch (Exception e2) { |  | ||||||
|                 //Ignored: brief pause so the usage message is flushed before exit |  | ||||||
|             } |  | ||||||
|             System.exit(1); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         if (jsonPath != null) { |  | ||||||
|             String json = FileUtils.readFileToString(new File(jsonPath)); |  | ||||||
|             TransformProcess transformProcess = TransformProcess.fromJson(json); |  | ||||||
|             transform = new CSVSparkTransform(transformProcess); |  | ||||||
|         } else { |  | ||||||
|             log.warn("Server started with no json for transform process. Please ensure you specify a transform process via sending a post request with raw json " |  | ||||||
|                     + "to /transformprocess"); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         //Set play secret key, if required |  | ||||||
|         //http://www.playframework.com/documentation/latest/ApplicationSecret |  | ||||||
|         String crypto = System.getProperty("play.crypto.secret"); |  | ||||||
|         if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) { |  | ||||||
|             byte[] newCrypto = new byte[1024]; |  | ||||||
| 
 |  | ||||||
|             new Random().nextBytes(newCrypto); |  | ||||||
| 
 |  | ||||||
|             String base64 = Base64.getEncoder().encodeToString(newCrypto); |  | ||||||
|             System.setProperty("play.crypto.secret", base64); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         server = Server.forRouter(Mode.PROD, port, this::createRouter); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     protected Router createRouter(BuiltInComponents builtInComponents) { |  | ||||||
|         RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents); |  | ||||||
| 
 |  | ||||||
|         routingDsl.GET("/transformprocess").routingTo(req -> { |  | ||||||
|             try { |  | ||||||
|                 if (transform == null) |  | ||||||
|                     return badRequest(); |  | ||||||
|                 return ok(transform.getTransformProcess().toJson()).as(contentType); |  | ||||||
|             } catch (Exception e) { |  | ||||||
|                 log.error("Error in GET /transformprocess",e); |  | ||||||
|                 return internalServerError(e.getMessage()); |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|         routingDsl.POST("/transformprocess").routingTo(req -> { |  | ||||||
|             try { |  | ||||||
|                 TransformProcess transformProcess = TransformProcess.fromJson(getJsonText(req)); |  | ||||||
|                 setCSVTransformProcess(transformProcess); |  | ||||||
|                 log.info("Transform process initialized"); |  | ||||||
|                 return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType); |  | ||||||
|             } catch (Exception e) { |  | ||||||
|                 log.error("Error in POST /transformprocess",e); |  | ||||||
|                 return internalServerError(e.getMessage()); |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|         routingDsl.POST("/transformincremental").routingTo(req -> { |  | ||||||
|             if (isSequence(req)) { |  | ||||||
|                 try { |  | ||||||
|                     BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); |  | ||||||
|                     if (record == null) |  | ||||||
|                         return badRequest(); |  | ||||||
|                     return ok(objectMapper.writeValueAsString(transformSequenceIncremental(record))).as(contentType); |  | ||||||
|                 } catch (Exception e) { |  | ||||||
|                     log.error("Error in /transformincremental", e); |  | ||||||
|                     return internalServerError(e.getMessage()); |  | ||||||
|                 } |  | ||||||
|             } else { |  | ||||||
|                 try { |  | ||||||
|                     SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class); |  | ||||||
|                     if (record == null) |  | ||||||
|                         return badRequest(); |  | ||||||
|                     return ok(objectMapper.writeValueAsString(transformIncremental(record))).as(contentType); |  | ||||||
|                 } catch (Exception e) { |  | ||||||
|                     log.error("Error in /transformincremental", e); |  | ||||||
|                     return internalServerError(e.getMessage()); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|         routingDsl.POST("/transform").routingTo(req -> { |  | ||||||
|             if (isSequence(req)) { |  | ||||||
|                 try { |  | ||||||
|                     SequenceBatchCSVRecord batch = transformSequence(objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class)); |  | ||||||
|                     if (batch == null) |  | ||||||
|                         return badRequest(); |  | ||||||
|                     return ok(objectMapper.writeValueAsString(batch)).as(contentType); |  | ||||||
|                 } catch (Exception e) { |  | ||||||
|                     log.error("Error in /transform", e); |  | ||||||
|                     return internalServerError(e.getMessage()); |  | ||||||
|                 } |  | ||||||
|             } else { |  | ||||||
|                 try { |  | ||||||
|                     BatchCSVRecord input = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); |  | ||||||
|                     BatchCSVRecord batch = transform(input); |  | ||||||
|                     if (batch == null) |  | ||||||
|                         return badRequest(); |  | ||||||
|                     return ok(objectMapper.writeValueAsString(batch)).as(contentType); |  | ||||||
|                 } catch (Exception e) { |  | ||||||
|                     log.error("Error in /transform", e); |  | ||||||
|                     return internalServerError(e.getMessage()); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|         routingDsl.POST("/transformincrementalarray").routingTo(req -> { |  | ||||||
|             if (isSequence(req)) { |  | ||||||
|                 try { |  | ||||||
|                     BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); |  | ||||||
|                     if (record == null) |  | ||||||
|                         return badRequest(); |  | ||||||
|                     return ok(objectMapper.writeValueAsString(transformSequenceArrayIncremental(record))).as(contentType); |  | ||||||
|                 } catch (Exception e) { |  | ||||||
|                     log.error("Error in /transformincrementalarray", e); |  | ||||||
|                     return internalServerError(e.getMessage()); |  | ||||||
|                 } |  | ||||||
|             } else { |  | ||||||
|                 try { |  | ||||||
|                     SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class); |  | ||||||
|                     if (record == null) |  | ||||||
|                         return badRequest(); |  | ||||||
|                     return ok(objectMapper.writeValueAsString(transformArrayIncremental(record))).as(contentType); |  | ||||||
|                 } catch (Exception e) { |  | ||||||
|                     log.error("Error in /transformincrementalarray", e); |  | ||||||
|                     return internalServerError(e.getMessage()); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|         routingDsl.POST("/transformarray").routingTo(req -> { |  | ||||||
|             if (isSequence(req)) { |  | ||||||
|                 try { |  | ||||||
|                     SequenceBatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class); |  | ||||||
|                     if (batchCSVRecord == null) |  | ||||||
|                         return badRequest(); |  | ||||||
|                     return ok(objectMapper.writeValueAsString(transformSequenceArray(batchCSVRecord))).as(contentType); |  | ||||||
|                 } catch (Exception e) { |  | ||||||
|                     log.error("Error in /transformarray", e); |  | ||||||
|                     return internalServerError(e.getMessage()); |  | ||||||
|                 } |  | ||||||
|             } else { |  | ||||||
|                 try { |  | ||||||
|                     BatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class); |  | ||||||
|                     if (batchCSVRecord == null) |  | ||||||
|                         return badRequest(); |  | ||||||
|                     return ok(objectMapper.writeValueAsString(transformArray(batchCSVRecord))).as(contentType); |  | ||||||
|                 } catch (Exception e) { |  | ||||||
|                     log.error("Error in /transformarray", e); |  | ||||||
|                     return internalServerError(e.getMessage()); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|         return routingDsl.build(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public static void main(String[] args) throws Exception { |  | ||||||
|         new CSVSparkTransformServer().runMain(args); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param transformProcess the CSV transform process this server should apply |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public void setCSVTransformProcess(TransformProcess transformProcess) { |  | ||||||
|         this.transform = new CSVSparkTransform(transformProcess); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) { |  | ||||||
|         log.error("Unsupported operation: setImageTransformProcess not supported for class", getClass()); |  | ||||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @return the transform process currently being served |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public TransformProcess getCSVTransformProcess() { |  | ||||||
|         return transform.getTransformProcess(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public ImageTransformProcess getImageTransformProcess() { |  | ||||||
|         log.error("Unsupported operation: getImageTransformProcess not supported for class", getClass()); |  | ||||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param transform the batch to transform incrementally as a sequence |  | ||||||
|      * @return the transformed sequence batch |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) { |  | ||||||
|         return this.transform.transformSequenceIncremental(transform); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param batchCSVRecord the sequence batch to transform |  | ||||||
|      * @return the transformed sequence batch |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) { |  | ||||||
|         return transform.transformSequence(batchCSVRecord); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param batchCSVRecord the sequence batch to convert |  | ||||||
|      * @return the resulting array as a base64-encoded body |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) { |  | ||||||
|         return this.transform.transformSequenceArray(batchCSVRecord); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param singleCsvRecord the batch to convert incrementally |  | ||||||
|      * @return the resulting array as a base64-encoded body |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) { |  | ||||||
|         return this.transform.transformSequenceArrayIncremental(singleCsvRecord); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param transform the single record to transform |  | ||||||
|      * @return the transformed record |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public SingleCSVRecord transformIncremental(SingleCSVRecord transform) { |  | ||||||
|         return this.transform.transform(transform); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) { |  | ||||||
|         return this.transform.transform(batchCSVRecord); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param batchCSVRecord the batch to transform |  | ||||||
|      * @return the transformed batch |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) { |  | ||||||
|         return transform.transform(batchCSVRecord); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param batchCSVRecord the batch to convert |  | ||||||
|      * @return the resulting array as a base64-encoded body |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) { |  | ||||||
|         try { |  | ||||||
|             return this.transform.toArray(batchCSVRecord); |  | ||||||
|         } catch (IOException e) { |  | ||||||
|             log.error("Error in transformArray",e); |  | ||||||
|             throw new IllegalStateException("Transform array shouldn't throw exception"); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param singleCsvRecord the record to convert |  | ||||||
|      * @return the resulting array as a base64-encoded body |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) { |  | ||||||
|         try { |  | ||||||
|             return this.transform.toArray(singleCsvRecord); |  | ||||||
|         } catch (IOException e) { |  | ||||||
|             log.error("Error in transformArrayIncremental",e); |  | ||||||
|             throw new IllegalStateException("Transform array shouldn't throw exception"); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException { |  | ||||||
|         log.error("Unsupported operation: transformIncrementalArray(SingleImageRecord) not supported for class", getClass()); |  | ||||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException { |  | ||||||
|         log.error("Unsupported operation: transformArray(BatchImageRecord) not supported for class", getClass()); |  | ||||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
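| Note: the sketch below is editorial and was never part of the removed sources. It shows how the routes above were typically exercised, assuming a server on localhost:9000 and the same com.mashape.unirest client the tests further down depend on; the JSON payloads and class name are hypothetical placeholders. |  | ||||||
| import com.mashape.unirest.http.HttpResponse; |  | ||||||
| import com.mashape.unirest.http.Unirest; |  | ||||||
| public class CSVTransformClientSketch { |  | ||||||
|     public static void main(String[] args) throws Exception { |  | ||||||
|         //Hypothetical TransformProcess JSON; in practice the output of TransformProcess.toJson() |  | ||||||
|         String transformJson = "{}"; |  | ||||||
|         //Initialize the server-side transform process (POST /transformprocess) |  | ||||||
|         HttpResponse<String> posted = Unirest.post("http://localhost:9000/transformprocess") |  | ||||||
|                 .header("Content-Type", "application/json") |  | ||||||
|                 .body(transformJson) |  | ||||||
|                 .asString(); |  | ||||||
|         System.out.println(posted.getBody()); |  | ||||||
|         //Transform a batch of records (POST /transform); the body is a BatchCSVRecord as JSON (shape assumed) |  | ||||||
|         HttpResponse<String> transformed = Unirest.post("http://localhost:9000/transform") |  | ||||||
|                 .header("Content-Type", "application/json") |  | ||||||
|                 .body("{\"records\":[{\"values\":[\"1.0\",\"2.0\"]}]}") |  | ||||||
|                 .asString(); |  | ||||||
|         System.out.println(transformed.getBody()); |  | ||||||
|     } |  | ||||||
| } |  | ||||||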
| @ -1,261 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.inference.server; |  | ||||||
| 
 |  | ||||||
| import com.beust.jcommander.JCommander; |  | ||||||
| import com.beust.jcommander.ParameterException; |  | ||||||
| import lombok.Data; |  | ||||||
| import lombok.extern.slf4j.Slf4j; |  | ||||||
| import org.apache.commons.io.FileUtils; |  | ||||||
| import org.datavec.api.transform.TransformProcess; |  | ||||||
| import org.datavec.image.transform.ImageTransformProcess; |  | ||||||
| import org.datavec.spark.inference.model.ImageSparkTransform; |  | ||||||
| import org.datavec.spark.inference.model.model.*; |  | ||||||
| import play.BuiltInComponents; |  | ||||||
| import play.Mode; |  | ||||||
| import play.libs.Files; |  | ||||||
| import play.mvc.Http; |  | ||||||
| import play.routing.Router; |  | ||||||
| import play.routing.RoutingDsl; |  | ||||||
| import play.server.Server; |  | ||||||
| 
 |  | ||||||
| import java.io.File; |  | ||||||
| import java.io.IOException; |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| import static play.mvc.Results.*; |  | ||||||
| 
 |  | ||||||
| @Slf4j |  | ||||||
| @Data |  | ||||||
| public class ImageSparkTransformServer extends SparkTransformServer { |  | ||||||
|     private ImageSparkTransform transform; |  | ||||||
| 
 |  | ||||||
|     public void runMain(String[] args) throws Exception { |  | ||||||
|         JCommander jcmdr = new JCommander(this); |  | ||||||
| 
 |  | ||||||
|         try { |  | ||||||
|             jcmdr.parse(args); |  | ||||||
|         } catch (ParameterException e) { |  | ||||||
|             //User provides invalid input -> print the usage info |  | ||||||
|             jcmdr.usage(); |  | ||||||
|             if (jsonPath == null) |  | ||||||
|                 System.err.println("Json path parameter is missing."); |  | ||||||
|             try { |  | ||||||
|                 Thread.sleep(500); |  | ||||||
|             } catch (Exception e2) { |  | ||||||
|                 //Ignored: brief pause so the usage message is flushed before exit |  | ||||||
|             } |  | ||||||
|             System.exit(1); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         if (jsonPath != null) { |  | ||||||
|             String json = FileUtils.readFileToString(new File(jsonPath)); |  | ||||||
|             ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(json); |  | ||||||
|             transform = new ImageSparkTransform(transformProcess); |  | ||||||
|         } else { |  | ||||||
|             log.warn("Server started with no json for transform process. Please ensure you specify a transform process via sending a post request with raw json " |  | ||||||
|                     + "to /transformprocess"); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         server = Server.forRouter(Mode.PROD, port, this::createRouter); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     protected Router createRouter(BuiltInComponents builtInComponents) { |  | ||||||
|         RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents); |  | ||||||
| 
 |  | ||||||
|         routingDsl.GET("/transformprocess").routingTo(req -> { |  | ||||||
|             try { |  | ||||||
|                 if (transform == null) |  | ||||||
|                     return badRequest(); |  | ||||||
|                 return ok(objectMapper.writeValueAsString(transform.getImageTransformProcess())).as(contentType); |  | ||||||
|             } catch (Exception e) { |  | ||||||
|                 log.error("",e); |  | ||||||
|                 return internalServerError(); |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|         routingDsl.POST("/transformprocess").routingTo(req -> { |  | ||||||
|             try { |  | ||||||
|                 ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(getJsonText(req)); |  | ||||||
|                 setImageTransformProcess(transformProcess); |  | ||||||
|                 log.info("Transform process initialized"); |  | ||||||
|                 return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType); |  | ||||||
|             } catch (Exception e) { |  | ||||||
|                 log.error("",e); |  | ||||||
|                 return internalServerError(); |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|         routingDsl.POST("/transformincrementalarray").routingTo(req -> { |  | ||||||
|             try { |  | ||||||
|                 SingleImageRecord record = objectMapper.readValue(getJsonText(req), SingleImageRecord.class); |  | ||||||
|                 if (record == null) |  | ||||||
|                     return badRequest(); |  | ||||||
|                 return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType); |  | ||||||
|             } catch (Exception e) { |  | ||||||
|                 log.error("",e); |  | ||||||
|                 return internalServerError(); |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|         routingDsl.POST("/transformincrementalimage").routingTo(req -> { |  | ||||||
|             try { |  | ||||||
|                 Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData(); |  | ||||||
|                 List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles(); |  | ||||||
|                 if (files.isEmpty() || files.get(0).getRef() == null) { |  | ||||||
|                     return badRequest(); |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 File file = files.get(0).getRef().path().toFile(); |  | ||||||
|                 SingleImageRecord record = new SingleImageRecord(file.toURI()); |  | ||||||
| 
 |  | ||||||
|                 return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType); |  | ||||||
|             } catch (Exception e) { |  | ||||||
|                 log.error("",e); |  | ||||||
|                 return internalServerError(); |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|         routingDsl.POST("/transformarray").routingTo(req -> { |  | ||||||
|             try { |  | ||||||
|                 BatchImageRecord batch = objectMapper.readValue(getJsonText(req), BatchImageRecord.class); |  | ||||||
|                 if (batch == null) |  | ||||||
|                     return badRequest(); |  | ||||||
|                 return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType); |  | ||||||
|             } catch (Exception e) { |  | ||||||
|                 log.error("",e); |  | ||||||
|                 return internalServerError(); |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|         routingDsl.POST("/transformimage").routingTo(req -> { |  | ||||||
|             try { |  | ||||||
|                 Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData(); |  | ||||||
|                 List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles(); |  | ||||||
|                 if (files.isEmpty()) { |  | ||||||
|                     return badRequest(); |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 List<SingleImageRecord> records = new ArrayList<>(); |  | ||||||
| 
 |  | ||||||
|                 for (Http.MultipartFormData.FilePart<Files.TemporaryFile> filePart : files) { |  | ||||||
|                     Files.TemporaryFile file = filePart.getRef(); |  | ||||||
|                     if (file != null) { |  | ||||||
|                         SingleImageRecord record = new SingleImageRecord(file.path().toUri()); |  | ||||||
|                         records.add(record); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 BatchImageRecord batch = new BatchImageRecord(records); |  | ||||||
| 
 |  | ||||||
|                 return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType); |  | ||||||
|             } catch (Exception e) { |  | ||||||
|                 log.error("",e); |  | ||||||
|                 return internalServerError(); |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|         return routingDsl.build(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) { |  | ||||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) { |  | ||||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) { |  | ||||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) { |  | ||||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void setCSVTransformProcess(TransformProcess transformProcess) { |  | ||||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) { |  | ||||||
|         this.transform = new ImageSparkTransform(imageTransformProcess); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public TransformProcess getCSVTransformProcess() { |  | ||||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public ImageTransformProcess getImageTransformProcess() { |  | ||||||
|         return transform.getImageTransformProcess(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public SingleCSVRecord transformIncremental(SingleCSVRecord singleCsvRecord) { |  | ||||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) { |  | ||||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) { |  | ||||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) { |  | ||||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) { |  | ||||||
|         throw new UnsupportedOperationException("Invalid operation for " + this.getClass()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Base64NDArrayBody transformIncrementalArray(SingleImageRecord record) throws IOException { |  | ||||||
|         return transform.toArray(record); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public Base64NDArrayBody transformArray(BatchImageRecord batch) throws IOException { |  | ||||||
|         return transform.toArray(batch); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public static void main(String[] args) throws Exception { |  | ||||||
|         new ImageSparkTransformServer().runMain(args); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
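| Note: another editorial sketch (not part of the removed sources) for the multipart image route above, assuming a server on localhost:9000; the image path and class name are hypothetical placeholders. |  | ||||||
| import com.mashape.unirest.http.HttpResponse; |  | ||||||
| import com.mashape.unirest.http.Unirest; |  | ||||||
| import java.io.File; |  | ||||||
| public class ImageTransformClientSketch { |  | ||||||
|     public static void main(String[] args) throws Exception { |  | ||||||
|         //POST /transformimage accepts one or more image files as multipart form data |  | ||||||
|         HttpResponse<String> response = Unirest.post("http://localhost:9000/transformimage") |  | ||||||
|                 .field("file", new File("/tmp/example.png")) //hypothetical path |  | ||||||
|                 .asString(); |  | ||||||
|         //The response body is a Base64NDArrayBody serialized as JSON |  | ||||||
|         System.out.println(response.getBody()); |  | ||||||
|     } |  | ||||||
| } |  | ||||||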
| @ -1,67 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.inference.server; |  | ||||||
| 
 |  | ||||||
| import com.beust.jcommander.Parameter; |  | ||||||
| import com.fasterxml.jackson.databind.JsonNode; |  | ||||||
| import org.datavec.spark.inference.model.model.Base64NDArrayBody; |  | ||||||
| import org.datavec.spark.inference.model.model.BatchCSVRecord; |  | ||||||
| import org.datavec.spark.inference.model.service.DataVecTransformService; |  | ||||||
| import org.nd4j.shade.jackson.databind.ObjectMapper; |  | ||||||
| import play.mvc.Http; |  | ||||||
| import play.server.Server; |  | ||||||
| 
 |  | ||||||
| public abstract class SparkTransformServer implements DataVecTransformService { |  | ||||||
|     @Parameter(names = {"-j", "--jsonPath"}, arity = 1) |  | ||||||
|     protected String jsonPath = null; |  | ||||||
|     @Parameter(names = {"-dp", "--dataVecPort"}, arity = 1) |  | ||||||
|     protected int port = 9000; |  | ||||||
|     @Parameter(names = {"-dt", "--dataType"}, arity = 1) |  | ||||||
|     private TransformDataType transformDataType = null; |  | ||||||
|     protected Server server; |  | ||||||
|     protected static ObjectMapper objectMapper = new ObjectMapper(); |  | ||||||
|     protected static String contentType = "application/json"; |  | ||||||
| 
 |  | ||||||
|     public abstract void runMain(String[] args) throws Exception; |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Stop the server |  | ||||||
|      */ |  | ||||||
|     public void stop() { |  | ||||||
|         if (server != null) |  | ||||||
|             server.stop(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     protected boolean isSequence(Http.Request request) { |  | ||||||
|         return request.hasHeader(SEQUENCE_OR_NOT_HEADER) |  | ||||||
|                 && request.header(SEQUENCE_OR_NOT_HEADER).get().equalsIgnoreCase("true"); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     protected String getJsonText(Http.Request request) { |  | ||||||
|         JsonNode tryJson = request.body().asJson(); |  | ||||||
|         if (tryJson != null) |  | ||||||
|             return tryJson.toString(); |  | ||||||
|         else |  | ||||||
|             return request.body().asText(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public abstract Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord); |  | ||||||
| } |  | ||||||
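| Note: isSequence() above switches on the SEQUENCE_OR_NOT_HEADER constant inherited from DataVecTransformService, comparing its value case-insensitively to "true". Below is an editorial sketch (not part of the removed sources) of a client opting into the sequence branch; the payload and class name are hypothetical placeholders. |  | ||||||
| import com.mashape.unirest.http.HttpResponse; |  | ||||||
| import com.mashape.unirest.http.Unirest; |  | ||||||
| import org.datavec.spark.inference.model.service.DataVecTransformService; |  | ||||||
| public class SequenceHeaderSketch { |  | ||||||
|     public static void main(String[] args) throws Exception { |  | ||||||
|         //Omitting the header (or any value other than "true") selects the non-sequence branch |  | ||||||
|         HttpResponse<String> response = Unirest.post("http://localhost:9000/transform") |  | ||||||
|                 .header("Content-Type", "application/json") |  | ||||||
|                 .header(DataVecTransformService.SEQUENCE_OR_NOT_HEADER, "true") |  | ||||||
|                 .body("{\"records\":[]}") //hypothetical SequenceBatchCSVRecord JSON |  | ||||||
|                 .asString(); |  | ||||||
|         System.out.println(response.getBody()); |  | ||||||
|     } |  | ||||||
| } |  | ||||||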
| @ -1,76 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.inference.server; |  | ||||||
| 
 |  | ||||||
| import lombok.Data; |  | ||||||
| import lombok.extern.slf4j.Slf4j; |  | ||||||
| 
 |  | ||||||
| import java.io.InvalidClassException; |  | ||||||
| import java.util.Arrays; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| @Slf4j |  | ||||||
| public class SparkTransformServerChooser { |  | ||||||
|     private SparkTransformServer sparkTransformServer = null; |  | ||||||
|     private TransformDataType transformDataType = null; |  | ||||||
| 
 |  | ||||||
|     public void runMain(String[] args) throws Exception { |  | ||||||
| 
 |  | ||||||
|         int pos = getMatchingPosition(args, "-dt", "--dataType"); |  | ||||||
|         if (pos == -1 || pos + 1 >= args.length) { |  | ||||||
|             log.error("Missing required data type option"); |  | ||||||
|             log.error("-dt, --dataType   Options: [CSV, IMAGE]"); |  | ||||||
|             throw new Exception("Missing required option: -dt/--dataType [CSV, IMAGE]"); |  | ||||||
|         } else { |  | ||||||
|             transformDataType = TransformDataType.valueOf(args[pos + 1]); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         switch (transformDataType) { |  | ||||||
|             case CSV: |  | ||||||
|                 sparkTransformServer = new CSVSparkTransformServer(); |  | ||||||
|                 break; |  | ||||||
|             case IMAGE: |  | ||||||
|                 sparkTransformServer = new ImageSparkTransformServer(); |  | ||||||
|                 break; |  | ||||||
|             default: |  | ||||||
|                 throw new InvalidClassException("no matching SparkTransform class"); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         sparkTransformServer.runMain(args); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private int getMatchingPosition(String[] args, String... options) { |  | ||||||
|         List<String> optionList = Arrays.asList(options); |  | ||||||
| 
 |  | ||||||
|         for (int i = 0; i < args.length; i++) { |  | ||||||
|             if (optionList.contains(args[i])) { |  | ||||||
|                 return i; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         return -1; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public static void main(String[] args) throws Exception { |  | ||||||
|         new SparkTransformServerChooser().runMain(args); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
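| Note: an editorial sketch (not part of the removed sources) of launching the chooser programmatically with the options the classes above actually parse (-dt from this class; -j and -dp from SparkTransformServer); the JSON path and class name are hypothetical placeholders. |  | ||||||
| import org.datavec.spark.inference.server.SparkTransformServerChooser; |  | ||||||
| public class ChooserLaunchSketch { |  | ||||||
|     public static void main(String[] args) throws Exception { |  | ||||||
|         //Equivalent to the command line: -dt CSV -j /tmp/transformprocess.json -dp 9000 |  | ||||||
|         new SparkTransformServerChooser().runMain(new String[] { |  | ||||||
|                 "-dt", "CSV",                       //selects CSVSparkTransformServer |  | ||||||
|                 "-j", "/tmp/transformprocess.json", //hypothetical TransformProcess JSON path |  | ||||||
|                 "-dp", "9000"                       //DataVec port (defaults to 9000) |  | ||||||
|         }); |  | ||||||
|     } |  | ||||||
| } |  | ||||||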
| @ -1,25 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.inference.server; |  | ||||||
| 
 |  | ||||||
| public enum TransformDataType { |  | ||||||
|     CSV, IMAGE, |  | ||||||
| } |  | ||||||
| @ -1,350 +0,0 @@ | |||||||
| # This is the main configuration file for the application. |  | ||||||
| # https://www.playframework.com/documentation/latest/ConfigFile |  | ||||||
| # ~~~~~ |  | ||||||
| # Play uses HOCON as its configuration file format.  HOCON has a number |  | ||||||
| # of advantages over other config formats, and two of its features are |  | ||||||
| # particularly useful when modifying settings. |  | ||||||
| # |  | ||||||
| # You can include other configuration files in this main application.conf file: |  | ||||||
| #include "extra-config.conf" |  | ||||||
| # |  | ||||||
| # You can declare variables and substitute for them: |  | ||||||
| #mykey = ${some.value} |  | ||||||
| # |  | ||||||
| # And if an environment variable exists when there is no other substitution, then |  | ||||||
| # HOCON will fall back to substituting the environment variable: |  | ||||||
| #mykey = ${JAVA_HOME} |  | ||||||
| 
 |  | ||||||
| ## Akka |  | ||||||
| # https://www.playframework.com/documentation/latest/ScalaAkka#Configuration |  | ||||||
| # https://www.playframework.com/documentation/latest/JavaAkka#Configuration |  | ||||||
| # ~~~~~ |  | ||||||
| # Play uses Akka internally and exposes Akka Streams and actors in Websockets and |  | ||||||
| # other streaming HTTP responses. |  | ||||||
| akka { |  | ||||||
|   # "akka.log-config-on-start" is extraordinarly useful because it log the complete |  | ||||||
|   # configuration at INFO level, including defaults and overrides, so it s worth |  | ||||||
|   # putting at the very top. |  | ||||||
|   # |  | ||||||
|   # Put the following in your conf/logback.xml file: |  | ||||||
|   # |  | ||||||
|   # <logger name="akka.actor" level="INFO" /> |  | ||||||
|   # |  | ||||||
|   # And then uncomment this line to debug the configuration. |  | ||||||
|   # |  | ||||||
|   #log-config-on-start = true |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| ## Modules |  | ||||||
| # https://www.playframework.com/documentation/latest/Modules |  | ||||||
| # ~~~~~ |  | ||||||
| # Control which modules are loaded when Play starts.  Note that modules are |  | ||||||
| # the replacement for "GlobalSettings", which are deprecated in 2.5.x. |  | ||||||
| # Please see https://www.playframework.com/documentation/latest/GlobalSettings |  | ||||||
| # for more information. |  | ||||||
| # |  | ||||||
| # You can also extend Play functionality by using one of the publicly available |  | ||||||
| # Play modules: https://playframework.com/documentation/latest/ModuleDirectory |  | ||||||
| play.modules { |  | ||||||
|   # By default, Play will load any class called Module that is defined |  | ||||||
|   # in the root package (the "app" directory), or you can define them |  | ||||||
|   # explicitly below. |  | ||||||
|   # If there are any built-in modules that you want to enable, you can list them here. |  | ||||||
|   #enabled += my.application.Module |  | ||||||
| 
 |  | ||||||
|   # If there are any built-in modules that you want to disable, you can list them here. |  | ||||||
|   #disabled += "" |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| ## Internationalisation |  | ||||||
| # https://www.playframework.com/documentation/latest/JavaI18N |  | ||||||
| # https://www.playframework.com/documentation/latest/ScalaI18N |  | ||||||
| # ~~~~~ |  | ||||||
| # Play comes with its own i18n settings, which allow the user's preferred language |  | ||||||
| # to map through to internal messages, or allow the language to be stored in a cookie. |  | ||||||
| play.i18n { |  | ||||||
|   # The application languages |  | ||||||
|   langs = [ "en" ] |  | ||||||
| 
 |  | ||||||
|   # Whether the language cookie should be secure or not |  | ||||||
|   #langCookieSecure = true |  | ||||||
| 
 |  | ||||||
|   # Whether the HTTP only attribute of the cookie should be set to true |  | ||||||
|   #langCookieHttpOnly = true |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| ## Play HTTP settings |  | ||||||
| # ~~~~~ |  | ||||||
| play.http { |  | ||||||
|   ## Router |  | ||||||
|   # https://www.playframework.com/documentation/latest/JavaRouting |  | ||||||
|   # https://www.playframework.com/documentation/latest/ScalaRouting |  | ||||||
|   # ~~~~~ |  | ||||||
|   # Define the Router object to use for this application. |  | ||||||
|   # This router will be looked up first when the application is starting up, |  | ||||||
|   # so make sure this is the entry point. |  | ||||||
|   # Furthermore, it's assumed your route file is named properly. |  | ||||||
|   # So for an application router like `my.application.Router`, |  | ||||||
|   # you may need to define a router file `conf/my.application.routes`. |  | ||||||
|   # Default to Routes in the root package (aka "apps" folder) (and conf/routes) |  | ||||||
|   #router = my.application.Router |  | ||||||
| 
 |  | ||||||
|   ## Action Creator |  | ||||||
|   # https://www.playframework.com/documentation/latest/JavaActionCreator |  | ||||||
|   # ~~~~~ |  | ||||||
|   #actionCreator = null |  | ||||||
| 
 |  | ||||||
|   ## ErrorHandler |  | ||||||
|   # https://www.playframework.com/documentation/latest/JavaRouting |  | ||||||
|   # https://www.playframework.com/documentation/latest/ScalaRouting |  | ||||||
|   # ~~~~~ |  | ||||||
|   # If null, will attempt to load a class called ErrorHandler in the root package, |  | ||||||
|   #errorHandler = null |  | ||||||
| 
 |  | ||||||
|   ## Filters |  | ||||||
|   # https://www.playframework.com/documentation/latest/ScalaHttpFilters |  | ||||||
|   # https://www.playframework.com/documentation/latest/JavaHttpFilters |  | ||||||
|   # ~~~~~ |  | ||||||
|   # Filters run code on every request. They can be used to perform |  | ||||||
|   # common logic for all your actions, e.g. adding common headers. |  | ||||||
|   # Defaults to "Filters" in the root package (aka "apps" folder) |  | ||||||
|   # Alternatively you can explicitly register a class here. |  | ||||||
|   #filters += my.application.Filters |  | ||||||
| 
 |  | ||||||
|   ## Session & Flash |  | ||||||
|   # https://www.playframework.com/documentation/latest/JavaSessionFlash |  | ||||||
|   # https://www.playframework.com/documentation/latest/ScalaSessionFlash |  | ||||||
|   # ~~~~~ |  | ||||||
|   session { |  | ||||||
|     # Sets the cookie to be sent only over HTTPS. |  | ||||||
|     #secure = true |  | ||||||
| 
 |  | ||||||
|     # Sets the cookie to be accessed only by the server. |  | ||||||
|     #httpOnly = true |  | ||||||
| 
 |  | ||||||
|     # Sets the max-age field of the cookie to 5 minutes. |  | ||||||
|     # NOTE: this only sets when the browser will discard the cookie. Play will consider any |  | ||||||
|     # cookie value with a valid signature to be a valid session forever. To implement a server side session timeout, |  | ||||||
|     # you need to put a timestamp in the session and check it at regular intervals to possibly expire it. |  | ||||||
|     #maxAge = 300 |  | ||||||
| 
 |  | ||||||
|     # Sets the domain on the session cookie. |  | ||||||
|     #domain = "example.com" |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   flash { |  | ||||||
|     # Sets the cookie to be sent only over HTTPS. |  | ||||||
|     #secure = true |  | ||||||
| 
 |  | ||||||
|     # Sets the cookie to be accessed only by the server. |  | ||||||
|     #httpOnly = true |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| ## Netty Provider |  | ||||||
| # https://www.playframework.com/documentation/latest/SettingsNetty |  | ||||||
| # ~~~~~ |  | ||||||
| play.server.netty { |  | ||||||
|   # Whether the Netty wire should be logged |  | ||||||
|   #log.wire = true |  | ||||||
| 
 |  | ||||||
|   # If you run Play on Linux, you can use Netty's native socket transport |  | ||||||
|   # for higher performance with less garbage. |  | ||||||
|   #transport = "native" |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| ## WS (HTTP Client) |  | ||||||
| # https://www.playframework.com/documentation/latest/ScalaWS#Configuring-WS |  | ||||||
| # ~~~~~ |  | ||||||
| # The HTTP client primarily used for REST APIs.  The default client can be |  | ||||||
| # configured directly, but you can also create different client instances |  | ||||||
| # with customized settings. You must enable this by adding to build.sbt: |  | ||||||
| # |  | ||||||
| # libraryDependencies += ws // or javaWs if using java |  | ||||||
| # |  | ||||||
| play.ws { |  | ||||||
|   # Sets HTTP requests not to follow 302 redirects |  | ||||||
|   #followRedirects = false |  | ||||||
| 
 |  | ||||||
|   # Sets the maximum number of open HTTP connections for the client. |  | ||||||
|   #ahc.maxConnectionsTotal = 50 |  | ||||||
| 
 |  | ||||||
|   ## WS SSL |  | ||||||
|   # https://www.playframework.com/documentation/latest/WsSSL |  | ||||||
|   # ~~~~~ |  | ||||||
|   ssl { |  | ||||||
|     # Configuring HTTPS with Play WS does not require programming.  You can |  | ||||||
|     # set up both trustManager and keyManager for mutual authentication, and |  | ||||||
|     # turn on JSSE debugging in development with a reload. |  | ||||||
|     #debug.handshake = true |  | ||||||
|     #trustManager = { |  | ||||||
|     #  stores = [ |  | ||||||
|     #    { type = "JKS", path = "exampletrust.jks" } |  | ||||||
|     #  ] |  | ||||||
|     #} |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| ## Cache |  | ||||||
| # https://www.playframework.com/documentation/latest/JavaCache |  | ||||||
| # https://www.playframework.com/documentation/latest/ScalaCache |  | ||||||
| # ~~~~~ |  | ||||||
| # Play comes with an integrated cache API that can reduce the operational |  | ||||||
| # overhead of repeated requests. You must enable this by adding to build.sbt: |  | ||||||
| # |  | ||||||
| # libraryDependencies += cache |  | ||||||
| # |  | ||||||
| play.cache { |  | ||||||
|   # If you want to bind several caches, you can bind them individually |  | ||||||
|   #bindCaches = ["db-cache", "user-cache", "session-cache"] |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| ## Filters |  | ||||||
| # https://www.playframework.com/documentation/latest/Filters |  | ||||||
| # ~~~~~ |  | ||||||
| # There are a number of built-in filters that can be enabled and configured |  | ||||||
| # to give Play greater security.  You must enable this by adding to build.sbt: |  | ||||||
| # |  | ||||||
| # libraryDependencies += filters |  | ||||||
| # |  | ||||||
| play.filters { |  | ||||||
|   ## CORS filter configuration |  | ||||||
|   # https://www.playframework.com/documentation/latest/CorsFilter |  | ||||||
|   # ~~~~~ |  | ||||||
|   # CORS is a protocol that allows web applications to make requests from the browser |  | ||||||
|   # across different domains. |  | ||||||
|   # NOTE: You MUST apply the CORS configuration before the CSRF filter, as CSRF has |  | ||||||
|   # dependencies on CORS settings. |  | ||||||
|   cors { |  | ||||||
|     # Filter paths by a whitelist of path prefixes |  | ||||||
|     #pathPrefixes = ["/some/path", ...] |  | ||||||
| 
 |  | ||||||
|     # The allowed origins. If null, all origins are allowed. |  | ||||||
|     #allowedOrigins = ["http://www.example.com"] |  | ||||||
| 
 |  | ||||||
|     # The allowed HTTP methods. If null, all methods are allowed |  | ||||||
|     #allowedHttpMethods = ["GET", "POST"] |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   ## CSRF Filter |  | ||||||
|   # https://www.playframework.com/documentation/latest/ScalaCsrf#Applying-a-global-CSRF-filter |  | ||||||
|   # https://www.playframework.com/documentation/latest/JavaCsrf#Applying-a-global-CSRF-filter |  | ||||||
|   # ~~~~~ |  | ||||||
|   # Play supports multiple methods for verifying that a request is not a CSRF request. |  | ||||||
|   # The primary mechanism is a CSRF token. This token gets placed either in the query string |  | ||||||
|   # or body of every form submitted, and also gets placed in the user's session. |  | ||||||
|   # Play then verifies that both tokens are present and match. |  | ||||||
|   csrf { |  | ||||||
|     # Sets the cookie to be sent only over HTTPS |  | ||||||
|     #cookie.secure = true |  | ||||||
| 
 |  | ||||||
|     # Defaults to CSRFErrorHandler in the root package. |  | ||||||
|     #errorHandler = MyCSRFErrorHandler |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   ## Security headers filter configuration |  | ||||||
|   # https://www.playframework.com/documentation/latest/SecurityHeaders |  | ||||||
|   # ~~~~~ |  | ||||||
|   # Defines security headers that prevent XSS attacks. |  | ||||||
|   # If enabled, then all options are set to the below configuration by default: |  | ||||||
|   headers { |  | ||||||
|     # The X-Frame-Options header. If null, the header is not set. |  | ||||||
|     #frameOptions = "DENY" |  | ||||||
| 
 |  | ||||||
|     # The X-XSS-Protection header. If null, the header is not set. |  | ||||||
|     #xssProtection = "1; mode=block" |  | ||||||
| 
 |  | ||||||
|     # The X-Content-Type-Options header. If null, the header is not set. |  | ||||||
|     #contentTypeOptions = "nosniff" |  | ||||||
| 
 |  | ||||||
|     # The X-Permitted-Cross-Domain-Policies header. If null, the header is not set. |  | ||||||
|     #permittedCrossDomainPolicies = "master-only" |  | ||||||
| 
 |  | ||||||
|     # The Content-Security-Policy header. If null, the header is not set. |  | ||||||
|     #contentSecurityPolicy = "default-src 'self'" |  | ||||||
|   } |  | ||||||
| 
 |  | ||||||
|   ## Allowed hosts filter configuration |  | ||||||
|   # https://www.playframework.com/documentation/latest/AllowedHostsFilter |  | ||||||
|   # ~~~~~ |  | ||||||
|   # Play provides a filter that lets you configure which hosts can access your application. |  | ||||||
|   # This is useful to prevent cache poisoning attacks. |  | ||||||
|   hosts { |  | ||||||
|     # Allow requests to example.com, its subdomains, and localhost:9000. |  | ||||||
|     #allowed = [".example.com", "localhost:9000"] |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| ## Evolutions |  | ||||||
| # https://www.playframework.com/documentation/latest/Evolutions |  | ||||||
| # ~~~~~ |  | ||||||
| # Evolutions allows database scripts to be automatically run on startup in dev mode |  | ||||||
| # for database migrations. You must enable this by adding to build.sbt: |  | ||||||
| # |  | ||||||
| # libraryDependencies += evolutions |  | ||||||
| # |  | ||||||
| play.evolutions { |  | ||||||
|   # You can disable evolutions for a specific datasource if necessary |  | ||||||
|   #db.default.enabled = false |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| ## Database Connection Pool |  | ||||||
| # https://www.playframework.com/documentation/latest/SettingsJDBC |  | ||||||
| # ~~~~~ |  | ||||||
| # Play doesn't require a JDBC database to run, but you can easily enable one. |  | ||||||
| # |  | ||||||
| # libraryDependencies += jdbc |  | ||||||
| # |  | ||||||
| play.db { |  | ||||||
|   # The combination of these two settings results in "db.default" as the |  | ||||||
|   # default JDBC pool: |  | ||||||
|   #config = "db" |  | ||||||
|   #default = "default" |  | ||||||
| 
 |  | ||||||
|   # Play uses HikariCP as the default connection pool.  You can override |  | ||||||
|   # settings by changing the prototype: |  | ||||||
|   prototype { |  | ||||||
|     # Sets a fixed JDBC connection pool size of 50 |  | ||||||
|     #hikaricp.minimumIdle = 50 |  | ||||||
|     #hikaricp.maximumPoolSize = 50 |  | ||||||
|   } |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| ## JDBC Datasource |  | ||||||
| # https://www.playframework.com/documentation/latest/JavaDatabase |  | ||||||
| # https://www.playframework.com/documentation/latest/ScalaDatabase |  | ||||||
| # ~~~~~ |  | ||||||
| # Once JDBC datasource is set up, you can work with several different |  | ||||||
| # database options: |  | ||||||
| # |  | ||||||
| # Slick (Scala preferred option): https://www.playframework.com/documentation/latest/PlaySlick |  | ||||||
| # JPA (Java preferred option): https://playframework.com/documentation/latest/JavaJPA |  | ||||||
| # EBean: https://playframework.com/documentation/latest/JavaEbean |  | ||||||
| # Anorm: https://www.playframework.com/documentation/latest/ScalaAnorm |  | ||||||
| # |  | ||||||
| db { |  | ||||||
|   # You can declare as many datasources as you want. |  | ||||||
|   # By convention, the default datasource is named `default` |  | ||||||
| 
 |  | ||||||
|   # https://www.playframework.com/documentation/latest/Developing-with-the-H2-Database |  | ||||||
|   default.driver = org.h2.Driver |  | ||||||
|   default.url = "jdbc:h2:mem:play" |  | ||||||
|   #default.username = sa |  | ||||||
|   #default.password = "" |  | ||||||
| 
 |  | ||||||
|   # You can expose this datasource via JNDI if needed (Useful for JPA) |  | ||||||
|   default.jndiName=DefaultDS |  | ||||||
| 
 |  | ||||||
|   # You can turn on SQL logging for any datasource |  | ||||||
|   # https://www.playframework.com/documentation/latest/Highlights25#Logging-SQL-statements |  | ||||||
|   #default.logSql=true |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| jpa.default=defaultPersistenceUnit |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| #Increase the default maximum post length - used for remote listener functionality |  | ||||||
| #Without this setting, larger networks can trigger HTTP 413 (payload too large) responses |  | ||||||
| # parsers.text.maxLength is deprecated, use play.http.parser.maxMemoryBuffer instead |  | ||||||
| #parsers.text.maxLength=10M |  | ||||||
| play.http.parser.maxMemoryBuffer=10M |  | ||||||
| @ -1,46 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| package org.datavec.spark.transform; |  | ||||||
| 
 |  | ||||||
| import lombok.extern.slf4j.Slf4j; |  | ||||||
| import org.nd4j.common.tests.AbstractAssertTestsClass; |  | ||||||
| import org.nd4j.common.tests.BaseND4JTest; |  | ||||||
| 
 |  | ||||||
| import java.util.*; |  | ||||||
| 
 |  | ||||||
| @Slf4j |  | ||||||
| public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass { |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     protected Set<Class<?>> getExclusions() { |  | ||||||
| 	    //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts) |  | ||||||
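| 	    // For example, returning new HashSet<>(Arrays.asList(SomeManualTest.class)) |  | ||||||
| 	    // (a hypothetical class name) would exempt that test from the base-class rule |  | ||||||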
| 	    return new HashSet<>(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     protected String getPackageName() { |  | ||||||
|         return "org.datavec.spark.transform"; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     protected Class<?> getBaseClass() { |  | ||||||
|         return BaseND4JTest.class; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,127 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.transform; |  | ||||||
| 
 |  | ||||||
| import com.mashape.unirest.http.JsonNode; |  | ||||||
| import com.mashape.unirest.http.ObjectMapper; |  | ||||||
| import com.mashape.unirest.http.Unirest; |  | ||||||
| import org.apache.commons.io.FileUtils; |  | ||||||
| import org.datavec.api.transform.TransformProcess; |  | ||||||
| import org.datavec.api.transform.schema.Schema; |  | ||||||
| import org.datavec.spark.inference.server.CSVSparkTransformServer; |  | ||||||
| import org.datavec.spark.inference.model.model.Base64NDArrayBody; |  | ||||||
| import org.datavec.spark.inference.model.model.BatchCSVRecord; |  | ||||||
| import org.datavec.spark.inference.model.model.SingleCSVRecord; |  | ||||||
| import org.junit.AfterClass; |  | ||||||
| import org.junit.BeforeClass; |  | ||||||
| import org.junit.Test; |  | ||||||
| 
 |  | ||||||
| import java.io.File; |  | ||||||
| import java.io.IOException; |  | ||||||
| import java.util.UUID; |  | ||||||
| 
 |  | ||||||
| import static org.junit.Assert.assertNull; |  | ||||||
| import static org.junit.Assume.assumeNotNull; |  | ||||||
| 
 |  | ||||||
| public class CSVSparkTransformServerNoJsonTest { |  | ||||||
| 
 |  | ||||||
|     private static CSVSparkTransformServer server; |  | ||||||
|     private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); |  | ||||||
|     private static TransformProcess transformProcess = |  | ||||||
|                     new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build(); |  | ||||||
|     private static File fileSave = new File(UUID.randomUUID().toString() + ".json"); |  | ||||||
| 
 |  | ||||||
|     @BeforeClass |  | ||||||
|     public static void before() throws Exception { |  | ||||||
|         server = new CSVSparkTransformServer(); |  | ||||||
|         FileUtils.write(fileSave, transformProcess.toJson()); |  | ||||||
| 
 |  | ||||||
|         // Configure Unirest's object mapper only once; the setting is global |  | ||||||
|         Unirest.setObjectMapper(new ObjectMapper() { |  | ||||||
|             private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = |  | ||||||
|                     new org.nd4j.shade.jackson.databind.ObjectMapper(); |  | ||||||
| 
 |  | ||||||
|             public <T> T readValue(String value, Class<T> valueType) { |  | ||||||
|                 try { |  | ||||||
|                     return jacksonObjectMapper.readValue(value, valueType); |  | ||||||
|                 } catch (IOException e) { |  | ||||||
|                     throw new RuntimeException(e); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             public String writeValue(Object value) { |  | ||||||
|                 try { |  | ||||||
|                     return jacksonObjectMapper.writeValueAsString(value); |  | ||||||
|                 } catch (Exception e) { |  | ||||||
|                     throw new RuntimeException(e); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|         server.runMain(new String[] {"-dp", "9050"}); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @AfterClass |  | ||||||
|     public static void after() throws Exception { |  | ||||||
|         fileSave.delete(); |  | ||||||
|         server.stop(); |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testServer() throws Exception { |  | ||||||
|         assertNull(server.getTransform()); |  | ||||||
|         JsonNode jsonStatus = Unirest.post("http://localhost:9050/transformprocess") |  | ||||||
|                         .header("accept", "application/json").header("Content-Type", "application/json") |  | ||||||
|                         .body(transformProcess.toJson()).asJson().getBody(); |  | ||||||
|         assumeNotNull(server.getTransform()); |  | ||||||
| 
 |  | ||||||
|         String[] values = new String[] {"1.0", "2.0"}; |  | ||||||
|         SingleCSVRecord record = new SingleCSVRecord(values); |  | ||||||
|         JsonNode jsonNode = |  | ||||||
|                         Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json") |  | ||||||
|                                         .header("Content-Type", "application/json").body(record).asJson().getBody(); |  | ||||||
|         SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental") |  | ||||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(record) |  | ||||||
|                         .asObject(SingleCSVRecord.class).getBody(); |  | ||||||
| 
 |  | ||||||
|         BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); |  | ||||||
|         for (int i = 0; i < 3; i++) |  | ||||||
|             batchCSVRecord.add(singleCsvRecord); |  | ||||||
|     /*    BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform") |  | ||||||
|                         .header("accept", "application/json").header("Content-Type", "application/json") |  | ||||||
|                         .body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody(); |  | ||||||
| 
 |  | ||||||
|         Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray") |  | ||||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(record) |  | ||||||
|                         .asObject(Base64NDArrayBody.class).getBody(); |  | ||||||
| */ |  | ||||||
|         Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray") |  | ||||||
|                         .header("accept", "application/json").header("Content-Type", "application/json") |  | ||||||
|                         .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody(); |  | ||||||
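| 
 |  | ||||||
|         // Hypothetical follow-up (not part of the original test): the Base64-encoded |  | ||||||
|         // result could be decoded back to an INDArray via org.nd4j.serde.base64.Nd4jBase64, |  | ||||||
|         // assuming the Lombok-generated getNdarray() accessor on Base64NDArrayBody: |  | ||||||
|         //   INDArray decoded = Nd4jBase64.fromBase64(batchArray1.getNdarray()); |  | ||||||
|         //   assertEquals(3, decoded.size(0)); |  | ||||||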
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,121 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.transform; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import com.mashape.unirest.http.JsonNode; |  | ||||||
| import com.mashape.unirest.http.ObjectMapper; |  | ||||||
| import com.mashape.unirest.http.Unirest; |  | ||||||
| import org.apache.commons.io.FileUtils; |  | ||||||
| import org.datavec.api.transform.TransformProcess; |  | ||||||
| import org.datavec.api.transform.schema.Schema; |  | ||||||
| import org.datavec.spark.inference.server.CSVSparkTransformServer; |  | ||||||
| import org.datavec.spark.inference.model.model.Base64NDArrayBody; |  | ||||||
| import org.datavec.spark.inference.model.model.BatchCSVRecord; |  | ||||||
| import org.datavec.spark.inference.model.model.SingleCSVRecord; |  | ||||||
| import org.junit.AfterClass; |  | ||||||
| import org.junit.BeforeClass; |  | ||||||
| import org.junit.Test; |  | ||||||
| 
 |  | ||||||
| import java.io.File; |  | ||||||
| import java.io.IOException; |  | ||||||
| import java.util.UUID; |  | ||||||
| 
 |  | ||||||
| public class CSVSparkTransformServerTest { |  | ||||||
| 
 |  | ||||||
|     private static CSVSparkTransformServer server; |  | ||||||
|     private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); |  | ||||||
|     private static TransformProcess transformProcess = |  | ||||||
|                     new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build(); |  | ||||||
|     private static File fileSave = new File(UUID.randomUUID().toString() + ".json"); |  | ||||||
| 
 |  | ||||||
|     @BeforeClass |  | ||||||
|     public static void before() throws Exception { |  | ||||||
|         server = new CSVSparkTransformServer(); |  | ||||||
|         FileUtils.write(fileSave, transformProcess.toJson()); |  | ||||||
|         // Configure Unirest's object mapper only once; the setting is global |  | ||||||
| 
 |  | ||||||
|         Unirest.setObjectMapper(new ObjectMapper() { |  | ||||||
|             private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = |  | ||||||
|                             new org.nd4j.shade.jackson.databind.ObjectMapper(); |  | ||||||
| 
 |  | ||||||
|             public <T> T readValue(String value, Class<T> valueType) { |  | ||||||
|                 try { |  | ||||||
|                     return jacksonObjectMapper.readValue(value, valueType); |  | ||||||
|                 } catch (IOException e) { |  | ||||||
|                     throw new RuntimeException(e); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             public String writeValue(Object value) { |  | ||||||
|                 try { |  | ||||||
|                     return jacksonObjectMapper.writeValueAsString(value); |  | ||||||
|                 } catch (Exception e) { |  | ||||||
|                     throw new RuntimeException(e); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|         server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9050"}); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @AfterClass |  | ||||||
|     public static void after() throws Exception { |  | ||||||
|         fileSave.deleteOnExit(); |  | ||||||
|         server.stop(); |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testServer() throws Exception { |  | ||||||
|         String[] values = new String[] {"1.0", "2.0"}; |  | ||||||
|         SingleCSVRecord record = new SingleCSVRecord(values); |  | ||||||
|         JsonNode jsonNode = |  | ||||||
|                         Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json") |  | ||||||
|                                         .header("Content-Type", "application/json").body(record).asJson().getBody(); |  | ||||||
|         SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental") |  | ||||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(record) |  | ||||||
|                         .asObject(SingleCSVRecord.class).getBody(); |  | ||||||
| 
 |  | ||||||
|         BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); |  | ||||||
|         for (int i = 0; i < 3; i++) |  | ||||||
|             batchCSVRecord.add(singleCsvRecord); |  | ||||||
|         BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform") |  | ||||||
|                         .header("accept", "application/json").header("Content-Type", "application/json") |  | ||||||
|                         .body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody(); |  | ||||||
| 
 |  | ||||||
|         Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray") |  | ||||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(record) |  | ||||||
|                         .asObject(Base64NDArrayBody.class).getBody(); |  | ||||||
| 
 |  | ||||||
|         Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray") |  | ||||||
|                         .header("accept", "application/json").header("Content-Type", "application/json") |  | ||||||
|                         .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody(); |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,164 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.transform; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import com.mashape.unirest.http.JsonNode; |  | ||||||
| import com.mashape.unirest.http.ObjectMapper; |  | ||||||
| import com.mashape.unirest.http.Unirest; |  | ||||||
| import org.apache.commons.io.FileUtils; |  | ||||||
| import org.datavec.image.transform.ImageTransformProcess; |  | ||||||
| import org.datavec.spark.inference.server.ImageSparkTransformServer; |  | ||||||
| import org.datavec.spark.inference.model.model.Base64NDArrayBody; |  | ||||||
| import org.datavec.spark.inference.model.model.BatchImageRecord; |  | ||||||
| import org.datavec.spark.inference.model.model.SingleImageRecord; |  | ||||||
| import org.junit.AfterClass; |  | ||||||
| import org.junit.BeforeClass; |  | ||||||
| import org.junit.Rule; |  | ||||||
| import org.junit.Test; |  | ||||||
| import org.junit.rules.TemporaryFolder; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.common.io.ClassPathResource; |  | ||||||
| import org.nd4j.serde.base64.Nd4jBase64; |  | ||||||
| 
 |  | ||||||
| import java.io.File; |  | ||||||
| import java.io.IOException; |  | ||||||
| import java.util.UUID; |  | ||||||
| 
 |  | ||||||
| import static org.junit.Assert.assertEquals; |  | ||||||
| 
 |  | ||||||
| public class ImageSparkTransformServerTest { |  | ||||||
| 
 |  | ||||||
|     @Rule |  | ||||||
|     public TemporaryFolder testDir = new TemporaryFolder(); |  | ||||||
| 
 |  | ||||||
|     private static ImageSparkTransformServer server; |  | ||||||
|     private static File fileSave = new File(UUID.randomUUID().toString() + ".json"); |  | ||||||
| 
 |  | ||||||
|     @BeforeClass |  | ||||||
|     public static void before() throws Exception { |  | ||||||
|         server = new ImageSparkTransformServer(); |  | ||||||
| 
 |  | ||||||
|         ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(12345) |  | ||||||
|                         .scaleImageTransform(10).cropImageTransform(5).build(); |  | ||||||
| 
 |  | ||||||
|         FileUtils.write(fileSave, imgTransformProcess.toJson()); |  | ||||||
| 
 |  | ||||||
|         Unirest.setObjectMapper(new ObjectMapper() { |  | ||||||
|             private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = |  | ||||||
|                             new org.nd4j.shade.jackson.databind.ObjectMapper(); |  | ||||||
| 
 |  | ||||||
|             public <T> T readValue(String value, Class<T> valueType) { |  | ||||||
|                 try { |  | ||||||
|                     return jacksonObjectMapper.readValue(value, valueType); |  | ||||||
|                 } catch (IOException e) { |  | ||||||
|                     throw new RuntimeException(e); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             public String writeValue(Object value) { |  | ||||||
|                 try { |  | ||||||
|                     return jacksonObjectMapper.writeValueAsString(value); |  | ||||||
|                 } catch (Exception e) { |  | ||||||
|                     throw new RuntimeException(e); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|         server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9060"}); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @AfterClass |  | ||||||
|     public static void after() throws Exception { |  | ||||||
|         fileSave.deleteOnExit(); |  | ||||||
|         server.stop(); |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testImageServer() throws Exception { |  | ||||||
|         SingleImageRecord record = |  | ||||||
|                         new SingleImageRecord(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI()); |  | ||||||
|         JsonNode jsonNode = Unirest.post("http://localhost:9060/transformincrementalarray") |  | ||||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(record) |  | ||||||
|                         .asJson().getBody(); |  | ||||||
|         Base64NDArrayBody array = Unirest.post("http://localhost:9060/transformincrementalarray") |  | ||||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(record) |  | ||||||
|                         .asObject(Base64NDArrayBody.class).getBody(); |  | ||||||
| 
 |  | ||||||
|         BatchImageRecord batch = new BatchImageRecord(); |  | ||||||
|         batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI()); |  | ||||||
|         batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile().toURI()); |  | ||||||
|         batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile().toURI()); |  | ||||||
| 
 |  | ||||||
|         JsonNode jsonNodeBatch = |  | ||||||
|                         Unirest.post("http://localhost:9060/transformarray").header("accept", "application/json") |  | ||||||
|                                         .header("Content-Type", "application/json").body(batch).asJson().getBody(); |  | ||||||
|         Base64NDArrayBody batchArray = Unirest.post("http://localhost:9060/transformarray") |  | ||||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(batch) |  | ||||||
|                         .asObject(Base64NDArrayBody.class).getBody(); |  | ||||||
| 
 |  | ||||||
|         INDArray result = getNDArray(jsonNode); |  | ||||||
|         assertEquals(1, result.size(0)); |  | ||||||
| 
 |  | ||||||
|         INDArray batchResult = getNDArray(jsonNodeBatch); |  | ||||||
|         assertEquals(3, batchResult.size(0)); |  | ||||||
| 
 |  | ||||||
| //        System.out.println(array); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testImageServerMultipart() throws Exception { |  | ||||||
|         JsonNode jsonNode = Unirest.post("http://localhost:9060/transformimage") |  | ||||||
|                 .header("accept", "application/json") |  | ||||||
|                 .field("file1", new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile()) |  | ||||||
|                 .field("file2", new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile()) |  | ||||||
|                 .field("file3", new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile()) |  | ||||||
|                 .asJson().getBody(); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         INDArray batchResult = getNDArray(jsonNode); |  | ||||||
|         assertEquals(3, batchResult.size(0)); |  | ||||||
| 
 |  | ||||||
| //        System.out.println(batchResult); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testImageServerSingleMultipart() throws Exception { |  | ||||||
|         File f = testDir.newFolder(); |  | ||||||
|         File imgFile = new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getTempFileFromArchive(f); |  | ||||||
| 
 |  | ||||||
|         JsonNode jsonNode = Unirest.post("http://localhost:9060/transformimage") |  | ||||||
|                 .header("accept", "application/json") |  | ||||||
|                 .field("file1", imgFile) |  | ||||||
|                 .asJson().getBody(); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         INDArray result = getNDArray(jsonNode); |  | ||||||
|         assertEquals(1, result.size(0)); |  | ||||||
| 
 |  | ||||||
| //        System.out.println(result); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public INDArray getNDArray(JsonNode node) throws IOException { |  | ||||||
|         return Nd4jBase64.fromBase64(node.getObject().getString("ndarray")); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,168 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.datavec.spark.transform; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import com.mashape.unirest.http.JsonNode; |  | ||||||
| import com.mashape.unirest.http.ObjectMapper; |  | ||||||
| import com.mashape.unirest.http.Unirest; |  | ||||||
| import org.apache.commons.io.FileUtils; |  | ||||||
| import org.datavec.api.transform.TransformProcess; |  | ||||||
| import org.datavec.api.transform.schema.Schema; |  | ||||||
| import org.datavec.image.transform.ImageTransformProcess; |  | ||||||
| import org.datavec.spark.inference.server.SparkTransformServerChooser; |  | ||||||
| import org.datavec.spark.inference.server.TransformDataType; |  | ||||||
| import org.datavec.spark.inference.model.model.*; |  | ||||||
| import org.junit.AfterClass; |  | ||||||
| import org.junit.BeforeClass; |  | ||||||
| import org.junit.Test; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.common.io.ClassPathResource; |  | ||||||
| import org.nd4j.serde.base64.Nd4jBase64; |  | ||||||
| 
 |  | ||||||
| import java.io.File; |  | ||||||
| import java.io.IOException; |  | ||||||
| import java.util.UUID; |  | ||||||
| 
 |  | ||||||
| import static org.junit.Assert.assertEquals; |  | ||||||
| 
 |  | ||||||
| public class SparkTransformServerTest { |  | ||||||
|     private static SparkTransformServerChooser serverChooser; |  | ||||||
|     private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build(); |  | ||||||
|     private static TransformProcess transformProcess = |  | ||||||
|                     new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build(); |  | ||||||
| 
 |  | ||||||
|     private static File imageTransformFile = new File(UUID.randomUUID().toString() + ".json"); |  | ||||||
|     private static File csvTransformFile = new File(UUID.randomUUID().toString() + ".json"); |  | ||||||
| 
 |  | ||||||
|     @BeforeClass |  | ||||||
|     public static void before() throws Exception { |  | ||||||
|         serverChooser = new SparkTransformServerChooser(); |  | ||||||
| 
 |  | ||||||
|         ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(12345) |  | ||||||
|                         .scaleImageTransform(10).cropImageTransform(5).build(); |  | ||||||
| 
 |  | ||||||
|         FileUtils.write(imageTransformFile, imgTransformProcess.toJson()); |  | ||||||
| 
 |  | ||||||
|         FileUtils.write(csvTransformFile, transformProcess.toJson()); |  | ||||||
| 
 |  | ||||||
|         Unirest.setObjectMapper(new ObjectMapper() { |  | ||||||
|             private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = |  | ||||||
|                             new org.nd4j.shade.jackson.databind.ObjectMapper(); |  | ||||||
| 
 |  | ||||||
|             public <T> T readValue(String value, Class<T> valueType) { |  | ||||||
|                 try { |  | ||||||
|                     return jacksonObjectMapper.readValue(value, valueType); |  | ||||||
|                 } catch (IOException e) { |  | ||||||
|                     throw new RuntimeException(e); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             public String writeValue(Object value) { |  | ||||||
|                 try { |  | ||||||
|                     return jacksonObjectMapper.writeValueAsString(value); |  | ||||||
|                 } catch (Exception e) { |  | ||||||
|                     throw new RuntimeException(e); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
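| 
 |  | ||||||
|         // Note: unlike the other server tests, no server is started here - each test |  | ||||||
|         // below launches the chooser itself with its own data type (-dt) and port |  | ||||||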
| 
 |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @AfterClass |  | ||||||
|     public static void after() throws Exception { |  | ||||||
|         imageTransformFile.deleteOnExit(); |  | ||||||
|         csvTransformFile.deleteOnExit(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testImageServer() throws Exception { |  | ||||||
|         serverChooser.runMain(new String[] {"--jsonPath", imageTransformFile.getAbsolutePath(), "-dp", "9060", "-dt", |  | ||||||
|                         TransformDataType.IMAGE.toString()}); |  | ||||||
| 
 |  | ||||||
|         SingleImageRecord record = |  | ||||||
|                         new SingleImageRecord(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI()); |  | ||||||
|         JsonNode jsonNode = Unirest.post("http://localhost:9060/transformincrementalarray") |  | ||||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(record) |  | ||||||
|                         .asJson().getBody(); |  | ||||||
|         Base64NDArrayBody array = Unirest.post("http://localhost:9060/transformincrementalarray") |  | ||||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(record) |  | ||||||
|                         .asObject(Base64NDArrayBody.class).getBody(); |  | ||||||
| 
 |  | ||||||
|         BatchImageRecord batch = new BatchImageRecord(); |  | ||||||
|         batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI()); |  | ||||||
|         batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile().toURI()); |  | ||||||
|         batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile().toURI()); |  | ||||||
| 
 |  | ||||||
|         JsonNode jsonNodeBatch = |  | ||||||
|                         Unirest.post("http://localhost:9060/transformarray").header("accept", "application/json") |  | ||||||
|                                         .header("Content-Type", "application/json").body(batch).asJson().getBody(); |  | ||||||
|         Base64NDArrayBody batchArray = Unirest.post("http://localhost:9060/transformarray") |  | ||||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(batch) |  | ||||||
|                         .asObject(Base64NDArrayBody.class).getBody(); |  | ||||||
| 
 |  | ||||||
|         INDArray result = getNDArray(jsonNode); |  | ||||||
|         assertEquals(1, result.size(0)); |  | ||||||
| 
 |  | ||||||
|         INDArray batchResult = getNDArray(jsonNodeBatch); |  | ||||||
|         assertEquals(3, batchResult.size(0)); |  | ||||||
| 
 |  | ||||||
|         serverChooser.getSparkTransformServer().stop(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testCSVServer() throws Exception { |  | ||||||
|         serverChooser.runMain(new String[] {"--jsonPath", csvTransformFile.getAbsolutePath(), "-dp", "9050", "-dt", |  | ||||||
|                         TransformDataType.CSV.toString()}); |  | ||||||
| 
 |  | ||||||
|         String[] values = new String[] {"1.0", "2.0"}; |  | ||||||
|         SingleCSVRecord record = new SingleCSVRecord(values); |  | ||||||
|         JsonNode jsonNode = |  | ||||||
|                         Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json") |  | ||||||
|                                         .header("Content-Type", "application/json").body(record).asJson().getBody(); |  | ||||||
|         SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental") |  | ||||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(record) |  | ||||||
|                         .asObject(SingleCSVRecord.class).getBody(); |  | ||||||
| 
 |  | ||||||
|         BatchCSVRecord batchCSVRecord = new BatchCSVRecord(); |  | ||||||
|         for (int i = 0; i < 3; i++) |  | ||||||
|             batchCSVRecord.add(singleCsvRecord); |  | ||||||
|         BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform") |  | ||||||
|                         .header("accept", "application/json").header("Content-Type", "application/json") |  | ||||||
|                         .body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody(); |  | ||||||
| 
 |  | ||||||
|         Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray") |  | ||||||
|                         .header("accept", "application/json").header("Content-Type", "application/json").body(record) |  | ||||||
|                         .asObject(Base64NDArrayBody.class).getBody(); |  | ||||||
| 
 |  | ||||||
|         Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray") |  | ||||||
|                         .header("accept", "application/json").header("Content-Type", "application/json") |  | ||||||
|                         .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody(); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         serverChooser.getSparkTransformServer().stop(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public INDArray getNDArray(JsonNode node) throws IOException { |  | ||||||
|         return Nd4jBase64.fromBase64(node.getObject().getString("ndarray")); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,6 +0,0 @@ | |||||||
| play.modules.enabled += com.lightbend.lagom.discovery.zookeeper.ZooKeeperServiceLocatorModule |  | ||||||
| play.modules.enabled += io.skymind.skil.service.PredictionModule |  | ||||||
| play.crypto.secret = as8dufasdfuasdfjkasdkfalksjfk |  | ||||||
| play.server.pidfile.path=/tmp/RUNNING_PID |  | ||||||
| 
 |  | ||||||
| play.server.http.port = 9600 |  | ||||||
| @ -1,68 +0,0 @@ | |||||||
| <?xml version="1.0" encoding="UTF-8"?> |  | ||||||
| <!-- |  | ||||||
|   ~ /* ****************************************************************************** |  | ||||||
|   ~  * |  | ||||||
|   ~  * |  | ||||||
|   ~  * This program and the accompanying materials are made available under the |  | ||||||
|   ~  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|   ~  * |  | ||||||
|   ~  *  See the NOTICE file distributed with this work for additional |  | ||||||
|   ~  *  information regarding copyright ownership. |  | ||||||
|   ~  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|   ~  * License for the specific language governing permissions and limitations |  | ||||||
|   ~  * under the License. |  | ||||||
|   ~  * |  | ||||||
|   ~  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|   ~  ******************************************************************************/ |  | ||||||
|   --> |  | ||||||
| 
 |  | ||||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" |  | ||||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" |  | ||||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> |  | ||||||
| 
 |  | ||||||
|     <modelVersion>4.0.0</modelVersion> |  | ||||||
| 
 |  | ||||||
|     <parent> |  | ||||||
|         <groupId>org.datavec</groupId> |  | ||||||
|         <artifactId>datavec-parent</artifactId> |  | ||||||
|         <version>1.0.0-SNAPSHOT</version> |  | ||||||
|     </parent> |  | ||||||
| 
 |  | ||||||
|     <artifactId>datavec-spark-inference-parent</artifactId> |  | ||||||
|     <packaging>pom</packaging> |  | ||||||
| 
 |  | ||||||
|     <name>datavec-spark-inference-parent</name> |  | ||||||
| 
 |  | ||||||
|     <modules> |  | ||||||
|         <module>datavec-spark-inference-server</module> |  | ||||||
|         <module>datavec-spark-inference-client</module> |  | ||||||
|         <module>datavec-spark-inference-model</module> |  | ||||||
|     </modules> |  | ||||||
| 
 |  | ||||||
|     <dependencyManagement> |  | ||||||
|         <dependencies> |  | ||||||
|             <dependency> |  | ||||||
|                 <groupId>org.datavec</groupId> |  | ||||||
|                 <artifactId>datavec-data-image</artifactId> |  | ||||||
|                 <version>${datavec.version}</version> |  | ||||||
|             </dependency> |  | ||||||
|             <dependency> |  | ||||||
|                 <groupId>com.mashape.unirest</groupId> |  | ||||||
|                 <artifactId>unirest-java</artifactId> |  | ||||||
|                 <version>${unirest.version}</version> |  | ||||||
|             </dependency> |  | ||||||
|         </dependencies> |  | ||||||
|     </dependencyManagement> |  | ||||||
| 
 |  | ||||||
|     <profiles> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-native</id> |  | ||||||
|         </profile> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-cuda-11.0</id> |  | ||||||
|         </profile> |  | ||||||
|     </profiles> |  | ||||||
| </project> |  | ||||||
| @ -45,7 +45,6 @@ | |||||||
|         <module>datavec-data</module> |         <module>datavec-data</module> | ||||||
|         <module>datavec-spark</module> |         <module>datavec-spark</module> | ||||||
|         <module>datavec-local</module> |         <module>datavec-local</module> | ||||||
|         <module>datavec-spark-inference-parent</module> |  | ||||||
|         <module>datavec-jdbc</module> |         <module>datavec-jdbc</module> | ||||||
|         <module>datavec-excel</module> |         <module>datavec-excel</module> | ||||||
|         <module>datavec-arrow</module> |         <module>datavec-arrow</module> | ||||||
|  | |||||||
| @ -1,143 +0,0 @@ | |||||||
| <?xml version="1.0" encoding="UTF-8"?> |  | ||||||
| <!-- |  | ||||||
|   ~ /* ****************************************************************************** |  | ||||||
|   ~  * |  | ||||||
|   ~  * |  | ||||||
|   ~  * This program and the accompanying materials are made available under the |  | ||||||
|   ~  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|   ~  * |  | ||||||
|   ~  *  See the NOTICE file distributed with this work for additional |  | ||||||
|   ~  *  information regarding copyright ownership. |  | ||||||
|   ~  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|   ~  * License for the specific language governing permissions and limitations |  | ||||||
|   ~  * under the License. |  | ||||||
|   ~  * |  | ||||||
|   ~  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|   ~  ******************************************************************************/ |  | ||||||
|   --> |  | ||||||
| 
 |  | ||||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" |  | ||||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" |  | ||||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> |  | ||||||
| 
 |  | ||||||
|     <modelVersion>4.0.0</modelVersion> |  | ||||||
| 
 |  | ||||||
|     <parent> |  | ||||||
|         <groupId>org.deeplearning4j</groupId> |  | ||||||
|         <artifactId>deeplearning4j-nearestneighbors-parent</artifactId> |  | ||||||
|         <version>1.0.0-SNAPSHOT</version> |  | ||||||
|     </parent> |  | ||||||
| 
 |  | ||||||
|     <artifactId>deeplearning4j-nearestneighbor-server</artifactId> |  | ||||||
|     <packaging>jar</packaging> |  | ||||||
| 
 |  | ||||||
|     <name>deeplearning4j-nearestneighbor-server</name> |  | ||||||
| 
 |  | ||||||
|     <properties> |  | ||||||
|         <java.compile.version>1.8</java.compile.version> |  | ||||||
|     </properties> |  | ||||||
| 
 |  | ||||||
|     <dependencies> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.deeplearning4j</groupId> |  | ||||||
|             <artifactId>deeplearning4j-nearestneighbors-model</artifactId> |  | ||||||
|             <version>${project.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.deeplearning4j</groupId> |  | ||||||
|             <artifactId>deeplearning4j-core</artifactId> |  | ||||||
|             <version>${project.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>io.vertx</groupId> |  | ||||||
|             <artifactId>vertx-core</artifactId> |  | ||||||
|             <version>${vertx.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>io.vertx</groupId> |  | ||||||
|             <artifactId>vertx-web</artifactId> |  | ||||||
|             <version>${vertx.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>com.mashape.unirest</groupId> |  | ||||||
|             <artifactId>unirest-java</artifactId> |  | ||||||
|             <version>${unirest.version}</version> |  | ||||||
|             <scope>test</scope> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.deeplearning4j</groupId> |  | ||||||
|             <artifactId>deeplearning4j-nearestneighbors-client</artifactId> |  | ||||||
|             <version>${project.version}</version> |  | ||||||
|             <scope>test</scope> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>com.beust</groupId> |  | ||||||
|             <artifactId>jcommander</artifactId> |  | ||||||
|             <version>${jcommander.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>ch.qos.logback</groupId> |  | ||||||
|             <artifactId>logback-classic</artifactId> |  | ||||||
|             <scope>test</scope> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.deeplearning4j</groupId> |  | ||||||
|             <artifactId>deeplearning4j-common-tests</artifactId> |  | ||||||
|             <version>${project.version}</version> |  | ||||||
|             <scope>test</scope> |  | ||||||
|         </dependency> |  | ||||||
|     </dependencies> |  | ||||||
| 
 |  | ||||||
|     <build> |  | ||||||
|         <plugins> |  | ||||||
|             <plugin> |  | ||||||
|                 <groupId>org.apache.maven.plugins</groupId> |  | ||||||
|                 <artifactId>maven-surefire-plugin</artifactId> |  | ||||||
|                 <configuration> |  | ||||||
|                     <argLine>-Dfile.encoding=UTF-8 -Xmx8g</argLine> |  | ||||||
|                     <includes> |  | ||||||
|                         <!-- Default setting only runs tests that start/end with "Test" --> |  | ||||||
|                         <include>*.java</include> |  | ||||||
|                         <include>**/*.java</include> |  | ||||||
|                     </includes> |  | ||||||
|                 </configuration> |  | ||||||
|             </plugin> |  | ||||||
|             <plugin> |  | ||||||
|                 <groupId>org.apache.maven.plugins</groupId> |  | ||||||
|                 <artifactId>maven-compiler-plugin</artifactId> |  | ||||||
|                 <configuration> |  | ||||||
|                     <source>${java.compile.version}</source> |  | ||||||
|                     <target>${java.compile.version}</target> |  | ||||||
|                 </configuration> |  | ||||||
|             </plugin> |  | ||||||
|         </plugins> |  | ||||||
|     </build> |  | ||||||
| 
 |  | ||||||
|     <profiles> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-native</id> |  | ||||||
|             <dependencies> |  | ||||||
|                 <dependency> |  | ||||||
|                     <groupId>org.nd4j</groupId> |  | ||||||
|                     <artifactId>nd4j-native</artifactId> |  | ||||||
|                     <version>${project.version}</version> |  | ||||||
|                     <scope>test</scope> |  | ||||||
|                 </dependency> |  | ||||||
|             </dependencies> |  | ||||||
|         </profile> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-cuda-11.0</id> |  | ||||||
|             <dependencies> |  | ||||||
|                 <dependency> |  | ||||||
|                     <groupId>org.nd4j</groupId> |  | ||||||
|                     <artifactId>nd4j-cuda-11.0</artifactId> |  | ||||||
|                     <version>${project.version}</version> |  | ||||||
|                     <scope>test</scope> |  | ||||||
|                 </dependency> |  | ||||||
|             </dependencies> |  | ||||||
|         </profile> |  | ||||||
|     </profiles> |  | ||||||
| </project> |  | ||||||
| @ -1,67 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.nearestneighbor.server; |  | ||||||
| 
 |  | ||||||
| import lombok.AllArgsConstructor; |  | ||||||
| import lombok.Builder; |  | ||||||
| import org.deeplearning4j.clustering.sptree.DataPoint; |  | ||||||
| import org.deeplearning4j.clustering.vptree.VPTree; |  | ||||||
| import org.deeplearning4j.nearestneighbor.model.NearestNeighborRequest; |  | ||||||
| import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| 
 |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| @AllArgsConstructor |  | ||||||
| @Builder |  | ||||||
| public class NearestNeighbor { |  | ||||||
|     private NearestNeighborRequest record; |  | ||||||
|     private VPTree tree; |  | ||||||
|     private INDArray points; |  | ||||||
| 
 |  | ||||||
|     public List<NearestNeighborsResult> search() { |  | ||||||
|         INDArray input = points.slice(record.getInputIndex()); |  | ||||||
|         List<NearestNeighborsResult> results = new ArrayList<>(); |  | ||||||
|         if (input.isVector()) { |  | ||||||
|             List<DataPoint> add = new ArrayList<>(); |  | ||||||
|             List<Double> distances = new ArrayList<>(); |  | ||||||
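|             // VPTree.search fills `add` and `distances` in parallel: entry i of |  | ||||||
|             // `distances` is the distance from the query to add.get(i) |  | ||||||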
|             tree.search(input, record.getK(), add, distances); |  | ||||||
| 
 |  | ||||||
|             if (add.size() != distances.size()) { |  | ||||||
|                 throw new IllegalStateException( |  | ||||||
|                         String.format("add.size == %d != %d == distances.size", |  | ||||||
|                                 add.size(), distances.size())); |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             for (int i = 0; i < add.size(); i++) { |  | ||||||
|                 results.add(new NearestNeighborsResult(add.get(i).getIndex(), distances.get(i))); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return results; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,278 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.nearestneighbor.server; |  | ||||||
| 
 |  | ||||||
| import com.beust.jcommander.JCommander; |  | ||||||
| import com.beust.jcommander.Parameter; |  | ||||||
| import com.beust.jcommander.ParameterException; |  | ||||||
| import io.netty.handler.codec.http.HttpResponseStatus; |  | ||||||
| import io.vertx.core.AbstractVerticle; |  | ||||||
| import io.vertx.core.Vertx; |  | ||||||
| import io.vertx.ext.web.Router; |  | ||||||
| import io.vertx.ext.web.handler.BodyHandler; |  | ||||||
| import lombok.extern.slf4j.Slf4j; |  | ||||||
| import org.apache.commons.io.FileUtils; |  | ||||||
| import org.deeplearning4j.clustering.sptree.DataPoint; |  | ||||||
| import org.deeplearning4j.clustering.vptree.VPTree; |  | ||||||
| import org.deeplearning4j.clustering.vptree.VPTreeFillSearch; |  | ||||||
| import org.deeplearning4j.exception.DL4JInvalidInputException; |  | ||||||
| import org.deeplearning4j.nearestneighbor.model.*; |  | ||||||
| import org.deeplearning4j.nn.conf.serde.JsonMappers; |  | ||||||
| import org.nd4j.linalg.api.buffer.DataBuffer; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.api.shape.Shape; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| import org.nd4j.linalg.indexing.NDArrayIndex; |  | ||||||
| import org.nd4j.serde.base64.Nd4jBase64; |  | ||||||
| import org.nd4j.serde.binary.BinarySerde; |  | ||||||
| 
 |  | ||||||
| import java.io.File; |  | ||||||
| import java.util.*; |  | ||||||
| 
 |  | ||||||
| @Slf4j |  | ||||||
| public class NearestNeighborsServer extends AbstractVerticle { |  | ||||||
| 
 |  | ||||||
|     private static class RunArgs { |  | ||||||
|         @Parameter(names = {"--ndarrayPath"}, arity = 1, required = true) |  | ||||||
|         private String ndarrayPath = null; |  | ||||||
|         @Parameter(names = {"--labelsPath"}, arity = 1, required = false) |  | ||||||
|         private String labelsPath = null; |  | ||||||
|         @Parameter(names = {"--nearestNeighborsPort"}, arity = 1) |  | ||||||
|         private int port = 9000; |  | ||||||
|         @Parameter(names = {"--similarityFunction"}, arity = 1) |  | ||||||
|         private String similarityFunction = "euclidean"; |  | ||||||
|         @Parameter(names = {"--invert"}, arity = 1) |  | ||||||
|         private boolean invert = false; |  | ||||||
|     } |  | ||||||
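| 
 |  | ||||||
|     // Hypothetical invocation (the paths are illustrative only): |  | ||||||
|     //   NearestNeighborsServer.runMain("--ndarrayPath", "chunk0.bin,chunk1.bin", |  | ||||||
|     //           "--nearestNeighborsPort", "9000", "--similarityFunction", "euclidean"); |  | ||||||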
| 
 |  | ||||||
|     private static RunArgs instanceArgs; |  | ||||||
|     private static NearestNeighborsServer instance; |  | ||||||
| 
 |  | ||||||
|     public NearestNeighborsServer(){ } |  | ||||||
| 
 |  | ||||||
|     public static NearestNeighborsServer getInstance(){ |  | ||||||
|         return instance; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public static void runMain(String... args) { |  | ||||||
|         RunArgs r = new RunArgs(); |  | ||||||
|         JCommander jcmdr = new JCommander(r); |  | ||||||
| 
 |  | ||||||
|         try { |  | ||||||
|             jcmdr.parse(args); |  | ||||||
|         } catch (ParameterException e) { |  | ||||||
|             log.error("Error in NearestNeighboursServer parameters", e); |  | ||||||
|             StringBuilder sb = new StringBuilder(); |  | ||||||
|             jcmdr.usage(sb); |  | ||||||
|             log.error("Usage: {}", sb.toString()); |  | ||||||
| 
 |  | ||||||
|             //User provides invalid input -> print the usage info |  | ||||||
|             jcmdr.usage(); |  | ||||||
|             if (r.ndarrayPath == null) |  | ||||||
|                 log.error("Json path parameter is missing (null)"); |  | ||||||
|             try { |  | ||||||
|                 Thread.sleep(500); |  | ||||||
|             } catch (Exception e2) { |  | ||||||
|             } |  | ||||||
|             System.exit(1); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         instanceArgs = r; |  | ||||||
|         try { |  | ||||||
|             Vertx vertx = Vertx.vertx(); |  | ||||||
|             vertx.deployVerticle(NearestNeighborsServer.class.getName()); |  | ||||||
|         } catch (Throwable t){ |  | ||||||
|             log.error("Error in NearestNeighboursServer run method",t); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public void start() throws Exception { |  | ||||||
|         instance = this; |  | ||||||
| 
 |  | ||||||
|         String[] pathArr = instanceArgs.ndarrayPath.split(","); |  | ||||||
|         //INDArray[] pointsArr = new INDArray[pathArr.length]; |  | ||||||
|         // First, read the shapes of the previously saved files to size the full matrix |  | ||||||
|         int rows = 0; |  | ||||||
|         int cols = 0; |  | ||||||
|         for (int i = 0; i < pathArr.length; i++) { |  | ||||||
|             DataBuffer shape = BinarySerde.readShapeFromDisk(new File(pathArr[i])); |  | ||||||
| 
 |  | ||||||
|             log.info("Loading shape {} of {}; Shape: [{} x {}]", i + 1, pathArr.length, Shape.size(shape, 0), |  | ||||||
|                     Shape.size(shape, 1)); |  | ||||||
| 
 |  | ||||||
|             if (Shape.rank(shape) != 2) |  | ||||||
|                 throw new DL4JInvalidInputException("NearestNeighborsServer assumes 2D chunks"); |  | ||||||
| 
 |  | ||||||
|             rows += Shape.size(shape, 0); |  | ||||||
| 
 |  | ||||||
|             if (cols == 0) |  | ||||||
|                 cols = Shape.size(shape, 1); |  | ||||||
|             else if (cols != Shape.size(shape, 1)) |  | ||||||
|                 throw new DL4JInvalidInputException( |  | ||||||
|                         "NearestNeighborsServer requires equal 2D chunks. Got columns mismatch."); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         final List<String> labels = new ArrayList<>(); |  | ||||||
|         if (instanceArgs.labelsPath != null) { |  | ||||||
|             String[] labelsPathArr = instanceArgs.labelsPath.split(","); |  | ||||||
|             for (int i = 0; i < labelsPathArr.length; i++) { |  | ||||||
|                 labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8")); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         if (!labels.isEmpty() && labels.size() != rows) |  | ||||||
|             throw new DL4JInvalidInputException(String.format("Number of labels must match number of rows in points matrix (expected %d, found %d)", rows, labels.size())); |  | ||||||
| 
 |  | ||||||
|         final INDArray points = Nd4j.createUninitialized(rows, cols); |  | ||||||
| 
 |  | ||||||
|         int lastPosition = 0; |  | ||||||
|         for (int i = 0; i < pathArr.length; i++) { |  | ||||||
|             log.info("Loading chunk {} of {}", i + 1, pathArr.length); |  | ||||||
|             INDArray pointsArr = BinarySerde.readFromDisk(new File(pathArr[i])); |  | ||||||
| 
 |  | ||||||
|             points.get(NDArrayIndex.interval(lastPosition, lastPosition + pointsArr.rows())).assign(pointsArr); |  | ||||||
|             lastPosition += pointsArr.rows(); |  | ||||||
| 
 |  | ||||||
|             // Hint GC between chunks so the previous chunk's temporary buffer can be reclaimed before the next load |  | ||||||
|             System.gc(); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         VPTree tree = new VPTree(points, instanceArgs.similarityFunction, instanceArgs.invert); |  | ||||||
| 
 |  | ||||||
|         //Set the Play application secret if required; retained from the earlier Play-based implementation |  | ||||||
|         //http://www.playframework.com/documentation/latest/ApplicationSecret |  | ||||||
|         String crypto = System.getProperty("play.crypto.secret"); |  | ||||||
|         if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) { |  | ||||||
|             byte[] newCrypto = new byte[1024]; |  | ||||||
| 
 |  | ||||||
|             new Random().nextBytes(newCrypto); |  | ||||||
| 
 |  | ||||||
|             String base64 = Base64.getEncoder().encodeToString(newCrypto); |  | ||||||
|             System.setProperty("play.crypto.secret", base64); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         Router r = Router.router(vertx); |  | ||||||
|         r.route().handler(BodyHandler.create());  //NOTE: Setting this is required to receive request body content at all |  | ||||||
|         createRoutes(r, labels, tree, points); |  | ||||||
| 
 |  | ||||||
|         vertx.createHttpServer() |  | ||||||
|                 .requestHandler(r) |  | ||||||
|                 .listen(instanceArgs.port); |  | ||||||
|     } |  | ||||||
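Because start() preallocates a rows x cols matrix and concatenates the chunks row-wise, every file passed via --ndarrayPath must contain a 2D array with the same column count. A minimal sketch of producing compatible inputs, using only the BinarySerde.writeArrayToDisk call the test below also relies on; the class name, paths, and sizes are illustrative.

import java.io.File;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.serde.binary.BinarySerde;

public class ChunkPrepSketch {
    public static void main(String[] args) throws Exception {
        INDArray chunk0 = Nd4j.rand(500, 64);
        INDArray chunk1 = Nd4j.rand(300, 64);   // column count must match across all chunks
        BinarySerde.writeArrayToDisk(chunk0, new File("/tmp/points-0.bin"));
        BinarySerde.writeArrayToDisk(chunk1, new File("/tmp/points-1.bin"));
        // Then launch with: --ndarrayPath /tmp/points-0.bin,/tmp/points-1.bin
        // Optional --labelsPath files must total 500 + 300 = 800 lines (one per row)
        // to satisfy the label/row count check above.
    }
}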
| 
 |  | ||||||
|     private void createRoutes(Router r, List<String> labels, VPTree tree, INDArray points){ |  | ||||||
| 
 |  | ||||||
|         r.post("/knn").handler(rc -> { |  | ||||||
|             try { |  | ||||||
|                 String json = rc.getBodyAsJson().encode(); |  | ||||||
|                 NearestNeighborRequest record = JsonMappers.getMapper().readValue(json, NearestNeighborRequest.class); |  | ||||||
| 
 |  | ||||||
|                 //Validate before building and running the search |  | ||||||
|                 if (record == null) { |  | ||||||
|                     rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()) |  | ||||||
|                             .putHeader("content-type", "application/json") |  | ||||||
|                             .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed."))); |  | ||||||
|                     return; |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 NearestNeighbor nearestNeighbor = |  | ||||||
|                         NearestNeighbor.builder().points(points).record(record).tree(tree).build(); |  | ||||||
| 
 |  | ||||||
|                 NearestNeighborsResults results = NearestNeighborsResults.builder().results(nearestNeighbor.search()).build(); |  | ||||||
| 
 |  | ||||||
|                 //Success: return the results as JSON with a 200 status |  | ||||||
|                 rc.response() |  | ||||||
|                         .putHeader("content-type", "application/json") |  | ||||||
|                         .end(JsonMappers.getMapper().writeValueAsString(results)); |  | ||||||
|                 return; |  | ||||||
|             } catch (Throwable e) { |  | ||||||
|                 log.error("Error in POST /knn", e); |  | ||||||
|                 rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) |  | ||||||
|                         .end("Error parsing request - " + e.getMessage()); |  | ||||||
|                 return; |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|         r.post("/knnnew").handler(rc -> { |  | ||||||
|             try { |  | ||||||
|                 String json = rc.getBodyAsJson().encode(); |  | ||||||
|                 Base64NDArrayBody record = JsonMappers.getMapper().readValue(json, Base64NDArrayBody.class); |  | ||||||
|                 if (record == null) { |  | ||||||
|                     rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code()) |  | ||||||
|                             .putHeader("content-type", "application/json") |  | ||||||
|                             .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed."))); |  | ||||||
|                     return; |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 INDArray arr = Nd4jBase64.fromBase64(record.getNdarray()); |  | ||||||
|                 List<DataPoint> results; |  | ||||||
|                 List<Double> distances; |  | ||||||
| 
 |  | ||||||
|                 if (record.isForceFillK()) { |  | ||||||
|                     VPTreeFillSearch vpTreeFillSearch = new VPTreeFillSearch(tree, record.getK(), arr); |  | ||||||
|                     vpTreeFillSearch.search(); |  | ||||||
|                     results = vpTreeFillSearch.getResults(); |  | ||||||
|                     distances = vpTreeFillSearch.getDistances(); |  | ||||||
|                 } else { |  | ||||||
|                     results = new ArrayList<>(); |  | ||||||
|                     distances = new ArrayList<>(); |  | ||||||
|                     tree.search(arr, record.getK(), results, distances); |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 if (results.size() != distances.size()) { |  | ||||||
|                     rc.response() |  | ||||||
|                             .setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) |  | ||||||
|                             .end(String.format("results.size == %d != %d == distances.size", results.size(), distances.size())); |  | ||||||
|                     return; |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 List<NearestNeighborsResult> nnResult = new ArrayList<>(); |  | ||||||
|                 for (int i=0; i<results.size(); i++) { |  | ||||||
|                     if (!labels.isEmpty()) |  | ||||||
|                         nnResult.add(new NearestNeighborsResult(results.get(i).getIndex(), distances.get(i), labels.get(results.get(i).getIndex()))); |  | ||||||
|                     else |  | ||||||
|                         nnResult.add(new NearestNeighborsResult(results.get(i).getIndex(), distances.get(i))); |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 NearestNeighborsResults results2 = NearestNeighborsResults.builder().results(nnResult).build(); |  | ||||||
|                 String j = JsonMappers.getMapper().writeValueAsString(results2); |  | ||||||
|                 rc.response() |  | ||||||
|                         .putHeader("content-type", "application/json") |  | ||||||
|                         .end(j); |  | ||||||
|             } catch (Throwable e) { |  | ||||||
|                 log.error("Error in POST /knnnew",e); |  | ||||||
|                 rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code()) |  | ||||||
|                         .end("Error parsing request - " + e.getMessage()); |  | ||||||
|                 return; |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Stop the server |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public void stop() throws Exception { |  | ||||||
|         super.stop(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public static void main(String[] args) throws Exception { |  | ||||||
|         runMain(args); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,161 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.nearestneighbor.server; |  | ||||||
| 
 |  | ||||||
| import org.deeplearning4j.BaseDL4JTest; |  | ||||||
| import org.deeplearning4j.clustering.sptree.DataPoint; |  | ||||||
| import org.deeplearning4j.clustering.vptree.VPTree; |  | ||||||
| import org.deeplearning4j.clustering.vptree.VPTreeFillSearch; |  | ||||||
| import org.deeplearning4j.nearestneighbor.client.NearestNeighborsClient; |  | ||||||
| import org.deeplearning4j.nearestneighbor.model.NearestNeighborRequest; |  | ||||||
| import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult; |  | ||||||
| import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResults; |  | ||||||
| import org.junit.Rule; |  | ||||||
| import org.junit.Test; |  | ||||||
| import org.junit.rules.TemporaryFolder; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| import org.nd4j.serde.binary.BinarySerde; |  | ||||||
| 
 |  | ||||||
| import java.io.File; |  | ||||||
| import java.io.IOException; |  | ||||||
| import java.net.ServerSocket; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| import static org.junit.Assert.assertEquals; |  | ||||||
| 
 |  | ||||||
| public class NearestNeighborTest extends BaseDL4JTest { |  | ||||||
| 
 |  | ||||||
|     @Rule |  | ||||||
|     public TemporaryFolder testDir = new TemporaryFolder(); |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testNearestNeighbor() { |  | ||||||
|         double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}}; |  | ||||||
|         INDArray arr = Nd4j.create(data); |  | ||||||
| 
 |  | ||||||
|         VPTree vpTree = new VPTree(arr, false); |  | ||||||
|         NearestNeighborRequest request = new NearestNeighborRequest(); |  | ||||||
|         request.setK(2); |  | ||||||
|         request.setInputIndex(0); |  | ||||||
|         NearestNeighbor nearestNeighbor = NearestNeighbor.builder().tree(vpTree).points(arr).record(request).build(); |  | ||||||
|         List<NearestNeighborsResult> results = nearestNeighbor.search(); |  | ||||||
|         assertEquals(1, results.get(0).getIndex()); |  | ||||||
|         assertEquals(2, results.size()); |  | ||||||
| 
 |  | ||||||
|         assertEquals(1.0, results.get(0).getDistance(), 1e-4); |  | ||||||
|         assertEquals(4.0, results.get(1).getDistance(), 1e-4); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testNearestNeighborInverted() { |  | ||||||
|         double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}}; |  | ||||||
|         INDArray arr = Nd4j.create(data); |  | ||||||
| 
 |  | ||||||
|         VPTree vpTree = new VPTree(arr, true); |  | ||||||
|         NearestNeighborRequest request = new NearestNeighborRequest(); |  | ||||||
|         request.setK(2); |  | ||||||
|         request.setInputIndex(0); |  | ||||||
|         NearestNeighbor nearestNeighbor = NearestNeighbor.builder().tree(vpTree).points(arr).record(request).build(); |  | ||||||
|         List<NearestNeighborsResult> results = nearestNeighbor.search(); |  | ||||||
|         assertEquals(2, results.get(0).getIndex()); |  | ||||||
|         assertEquals(2, results.size()); |  | ||||||
| 
 |  | ||||||
|         assertEquals(-4.0, results.get(0).getDistance(), 1e-4); |  | ||||||
|         assertEquals(-1.0, results.get(1).getDistance(), 1e-4); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void vpTreeTest() throws Exception { |  | ||||||
|         //Sanity check: VPTree construction over a random 400 x 10 matrix completes without throwing |  | ||||||
|         INDArray matrix = Nd4j.rand(new int[] {400, 10}); |  | ||||||
|         VPTree vpTree = new VPTree(matrix); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public static int getAvailablePort() { |  | ||||||
|         try (ServerSocket socket = new ServerSocket(0)) { |  | ||||||
|             return socket.getLocalPort(); |  | ||||||
|         } catch (IOException e) { |  | ||||||
|             throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testServer() throws Exception { |  | ||||||
|         int localPort = getAvailablePort(); |  | ||||||
|         Nd4j.getRandom().setSeed(7); |  | ||||||
|         INDArray rand = Nd4j.randn(10, 5); |  | ||||||
|         File writeToTmp = testDir.newFile(); |  | ||||||
|         writeToTmp.deleteOnExit(); |  | ||||||
|         BinarySerde.writeArrayToDisk(rand, writeToTmp); |  | ||||||
|         NearestNeighborsServer.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort", |  | ||||||
|                 String.valueOf(localPort)); |  | ||||||
| 
 |  | ||||||
|         Thread.sleep(3000); |  | ||||||
| 
 |  | ||||||
|         NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort); |  | ||||||
|         NearestNeighborsResults result = client.knnNew(5, rand.getRow(0)); |  | ||||||
|         assertEquals(5, result.getResults().size()); |  | ||||||
|         NearestNeighborsServer.getInstance().stop(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testFullSearch() throws Exception { |  | ||||||
|         int numRows = 1000; |  | ||||||
|         int numCols = 100; |  | ||||||
|         int numNeighbors = 42; |  | ||||||
|         INDArray points = Nd4j.rand(numRows, numCols); |  | ||||||
|         VPTree tree = new VPTree(points); |  | ||||||
|         INDArray query = Nd4j.rand(new int[] {1, numCols}); |  | ||||||
|         VPTreeFillSearch fillSearch = new VPTreeFillSearch(tree, numNeighbors, query); |  | ||||||
|         fillSearch.search(); |  | ||||||
|         List<DataPoint> results = fillSearch.getResults(); |  | ||||||
|         List<Double> distances = fillSearch.getDistances(); |  | ||||||
|         assertEquals(numNeighbors, distances.size()); |  | ||||||
|         assertEquals(numNeighbors, results.size()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testDistances() { |  | ||||||
| 
 |  | ||||||
|         INDArray indArray = Nd4j.create(new float[][]{{3, 4}, {1, 2}, {5, 6}}); |  | ||||||
|         INDArray record = Nd4j.create(new float[][]{{7, 6}}); |  | ||||||
|         VPTree vpTree = new VPTree(indArray, "euclidean", false); |  | ||||||
|         VPTreeFillSearch vpTreeFillSearch = new VPTreeFillSearch(vpTree, 3, record); |  | ||||||
|         vpTreeFillSearch.search(); |  | ||||||
|         //System.out.println(vpTreeFillSearch.getResults()); |  | ||||||
|         System.out.println(vpTreeFillSearch.getDistances()); |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,46 +0,0 @@ | |||||||
| <!-- |  | ||||||
|   ~ /* ****************************************************************************** |  | ||||||
|   ~  * |  | ||||||
|   ~  * |  | ||||||
|   ~  * This program and the accompanying materials are made available under the |  | ||||||
|   ~  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|   ~  * |  | ||||||
|   ~  *  See the NOTICE file distributed with this work for additional |  | ||||||
|   ~  *  information regarding copyright ownership. |  | ||||||
|   ~  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|   ~  * License for the specific language governing permissions and limitations |  | ||||||
|   ~  * under the License. |  | ||||||
|   ~  * |  | ||||||
|   ~  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|   ~  ******************************************************************************/ |  | ||||||
|   --> |  | ||||||
| 
 |  | ||||||
| <configuration> |  | ||||||
| 
 |  | ||||||
|     <appender name="FILE" class="ch.qos.logback.core.FileAppender"> |  | ||||||
|         <file>logs/application.log</file> |  | ||||||
|         <encoder> |  | ||||||
|             <pattern> %logger{15} - %message%n%xException{5} |  | ||||||
|             </pattern> |  | ||||||
|         </encoder> |  | ||||||
|     </appender> |  | ||||||
| 
 |  | ||||||
|     <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender"> |  | ||||||
|         <encoder> |  | ||||||
|             <pattern> %logger{15} - %message%n%xException{5} |  | ||||||
|             </pattern> |  | ||||||
|         </encoder> |  | ||||||
|     </appender> |  | ||||||
| 
 |  | ||||||
|     <logger name="org.deeplearning4j" level="INFO" /> |  | ||||||
|     <logger name="org.datavec" level="INFO" /> |  | ||||||
|     <logger name="org.nd4j" level="INFO" /> |  | ||||||
| 
 |  | ||||||
|     <root level="ERROR"> |  | ||||||
|         <appender-ref ref="STDOUT" /> |  | ||||||
|         <appender-ref ref="FILE" /> |  | ||||||
|     </root> |  | ||||||
| </configuration> |  | ||||||
| @ -1,60 +0,0 @@ | |||||||
| <?xml version="1.0" encoding="UTF-8"?> |  | ||||||
| <!-- |  | ||||||
|   ~ /* ****************************************************************************** |  | ||||||
|   ~  * |  | ||||||
|   ~  * |  | ||||||
|   ~  * This program and the accompanying materials are made available under the |  | ||||||
|   ~  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|   ~  * |  | ||||||
|   ~  *  See the NOTICE file distributed with this work for additional |  | ||||||
|   ~  *  information regarding copyright ownership. |  | ||||||
|   ~  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|   ~  * License for the specific language governing permissions and limitations |  | ||||||
|   ~  * under the License. |  | ||||||
|   ~  * |  | ||||||
|   ~  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|   ~  ******************************************************************************/ |  | ||||||
|   --> |  | ||||||
| 
 |  | ||||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" |  | ||||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" |  | ||||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> |  | ||||||
| 
 |  | ||||||
|     <modelVersion>4.0.0</modelVersion> |  | ||||||
| 
 |  | ||||||
|     <parent> |  | ||||||
|         <groupId>org.deeplearning4j</groupId> |  | ||||||
|         <artifactId>deeplearning4j-nearestneighbors-parent</artifactId> |  | ||||||
|         <version>1.0.0-SNAPSHOT</version> |  | ||||||
|     </parent> |  | ||||||
| 
 |  | ||||||
|     <artifactId>deeplearning4j-nearestneighbors-client</artifactId> |  | ||||||
|     <packaging>jar</packaging> |  | ||||||
| 
 |  | ||||||
|     <name>deeplearning4j-nearestneighbors-client</name> |  | ||||||
| 
 |  | ||||||
|     <dependencies> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>com.mashape.unirest</groupId> |  | ||||||
|             <artifactId>unirest-java</artifactId> |  | ||||||
|             <version>${unirest.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.deeplearning4j</groupId> |  | ||||||
|             <artifactId>deeplearning4j-nearestneighbors-model</artifactId> |  | ||||||
|             <version>${project.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|     </dependencies> |  | ||||||
| 
 |  | ||||||
|     <profiles> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-native</id> |  | ||||||
|         </profile> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-cuda-11.0</id> |  | ||||||
|         </profile> |  | ||||||
|     </profiles> |  | ||||||
| </project> |  | ||||||
| @ -1,137 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.nearestneighbor.client; |  | ||||||
| 
 |  | ||||||
| import com.mashape.unirest.http.ObjectMapper; |  | ||||||
| import com.mashape.unirest.http.Unirest; |  | ||||||
| import com.mashape.unirest.request.HttpRequest; |  | ||||||
| import lombok.AllArgsConstructor; |  | ||||||
| import lombok.Getter; |  | ||||||
| import lombok.Setter; |  | ||||||
| import lombok.val; |  | ||||||
| import org.deeplearning4j.nearestneighbor.model.*; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.serde.base64.Nd4jBase64; |  | ||||||
| import org.nd4j.shade.jackson.core.JsonProcessingException; |  | ||||||
| 
 |  | ||||||
| import java.io.IOException; |  | ||||||
| 
 |  | ||||||
| @AllArgsConstructor |  | ||||||
| public class NearestNeighborsClient { |  | ||||||
| 
 |  | ||||||
|     private String url; |  | ||||||
|     @Setter |  | ||||||
|     @Getter |  | ||||||
|     protected String authToken; |  | ||||||
| 
 |  | ||||||
|     public NearestNeighborsClient(String url){ |  | ||||||
|         this(url, null); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     static { |  | ||||||
|         // Register the JSON object mapper with Unirest once, for the lifetime of the JVM |  | ||||||
| 
 |  | ||||||
|         Unirest.setObjectMapper(new ObjectMapper() { |  | ||||||
|             private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper = |  | ||||||
|                             new org.nd4j.shade.jackson.databind.ObjectMapper(); |  | ||||||
| 
 |  | ||||||
|             public <T> T readValue(String value, Class<T> valueType) { |  | ||||||
|                 try { |  | ||||||
|                     return jacksonObjectMapper.readValue(value, valueType); |  | ||||||
|                 } catch (IOException e) { |  | ||||||
|                     throw new RuntimeException(e); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             public String writeValue(Object value) { |  | ||||||
|                 try { |  | ||||||
|                     return jacksonObjectMapper.writeValueAsString(value); |  | ||||||
|                 } catch (JsonProcessingException e) { |  | ||||||
|                     throw new RuntimeException(e); |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Runs knn on the given index with the given k. Note that this is for data |  | ||||||
|      * already within the existing dataset, not new data. |  | ||||||
|      * @param index the index of the EXISTING ndarray to run a search on |  | ||||||
|      * @param k the number of results to retrieve |  | ||||||
|      * @return the k nearest neighbors of the point at the given index |  | ||||||
|      * @throws Exception if the request or response handling fails |  | ||||||
|      */ |  | ||||||
|     public NearestNeighborsResults knn(int index, int k) throws Exception { |  | ||||||
|         NearestNeighborRequest request = new NearestNeighborRequest(); |  | ||||||
|         request.setInputIndex(index); |  | ||||||
|         request.setK(k); |  | ||||||
|         val req = Unirest.post(url + "/knn"); |  | ||||||
|         req.header("accept", "application/json") |  | ||||||
|                 .header("Content-Type", "application/json").body(request); |  | ||||||
|         addAuthHeader(req); |  | ||||||
| 
 |  | ||||||
|         NearestNeighborsResults ret = req.asObject(NearestNeighborsResults.class).getBody(); |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Run a k nearest neighbors search on a NEW data point |  | ||||||
|      * @param k the number of results to retrieve |  | ||||||
|      * @param arr the array to run the search on. Note that this must be a row vector |  | ||||||
|      * @return the k nearest neighbors of the given row vector |  | ||||||
|      * @throws Exception if the request or response handling fails |  | ||||||
|      */ |  | ||||||
|     public NearestNeighborsResults knnNew(int k, INDArray arr) throws Exception { |  | ||||||
|         Base64NDArrayBody base64NDArrayBody = |  | ||||||
|                         Base64NDArrayBody.builder().k(k).ndarray(Nd4jBase64.base64String(arr)).build(); |  | ||||||
| 
 |  | ||||||
|         val req = Unirest.post(url + "/knnnew"); |  | ||||||
|         req.header("accept", "application/json") |  | ||||||
|                 .header("Content-Type", "application/json").body(base64NDArrayBody); |  | ||||||
|         addAuthHeader(req); |  | ||||||
| 
 |  | ||||||
|         NearestNeighborsResults ret = req.asObject(NearestNeighborsResults.class).getBody(); |  | ||||||
| 
 |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Add the specified authentication header to the specified HttpRequest |  | ||||||
|      * |  | ||||||
|      * @param request HTTP Request to add the authentication header to |  | ||||||
|      */ |  | ||||||
|     protected HttpRequest addAuthHeader(HttpRequest request) { |  | ||||||
|         if (authToken != null) { |  | ||||||
|             request.header("authorization", "Bearer " + authToken); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return request; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
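A minimal usage sketch for the client against an already-running server; the class name, URL, dimensionality, and k are illustrative, while the client methods themselves are exactly those defined above.

import org.deeplearning4j.nearestneighbor.client.NearestNeighborsClient;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult;
import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResults;
import org.nd4j.linalg.factory.Nd4j;

public class ClientUsageSketch {
    public static void main(String[] args) throws Exception {
        NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:9000");

        // Neighbors of a point already in the server-side dataset, addressed by row index:
        NearestNeighborsResults byIndex = client.knn(0, 5);

        // Neighbors of a new query point; the array must be a row vector:
        NearestNeighborsResults forNew = client.knnNew(5, Nd4j.rand(1, 64));
        for (NearestNeighborsResult r : forNew.getResults()) {
            System.out.println(r.getIndex() + " -> " + r.getDistance());
        }
    }
}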
| @ -1,61 +0,0 @@ | |||||||
| <?xml version="1.0" encoding="UTF-8"?> |  | ||||||
| <!-- |  | ||||||
|   ~ /* ****************************************************************************** |  | ||||||
|   ~  * |  | ||||||
|   ~  * |  | ||||||
|   ~  * This program and the accompanying materials are made available under the |  | ||||||
|   ~  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|   ~  * |  | ||||||
|   ~  *  See the NOTICE file distributed with this work for additional |  | ||||||
|   ~  *  information regarding copyright ownership. |  | ||||||
|   ~  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|   ~  * License for the specific language governing permissions and limitations |  | ||||||
|   ~  * under the License. |  | ||||||
|   ~  * |  | ||||||
|   ~  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|   ~  ******************************************************************************/ |  | ||||||
|   --> |  | ||||||
| 
 |  | ||||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" |  | ||||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" |  | ||||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> |  | ||||||
| 
 |  | ||||||
|     <modelVersion>4.0.0</modelVersion> |  | ||||||
| 
 |  | ||||||
|     <parent> |  | ||||||
|         <groupId>org.deeplearning4j</groupId> |  | ||||||
|         <artifactId>deeplearning4j-nearestneighbors-parent</artifactId> |  | ||||||
|         <version>1.0.0-SNAPSHOT</version> |  | ||||||
|     </parent> |  | ||||||
| 
 |  | ||||||
|     <artifactId>deeplearning4j-nearestneighbors-model</artifactId> |  | ||||||
|     <packaging>jar</packaging> |  | ||||||
| 
 |  | ||||||
|     <name>deeplearning4j-nearestneighbors-model</name> |  | ||||||
| 
 |  | ||||||
|     <dependencies> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.projectlombok</groupId> |  | ||||||
|             <artifactId>lombok</artifactId> |  | ||||||
|             <version>${lombok.version}</version> |  | ||||||
|             <scope>provided</scope> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.nd4j</groupId> |  | ||||||
|             <artifactId>nd4j-api</artifactId> |  | ||||||
|             <version>${nd4j.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|     </dependencies> |  | ||||||
| 
 |  | ||||||
|     <profiles> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-native</id> |  | ||||||
|         </profile> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-cuda-11.0</id> |  | ||||||
|         </profile> |  | ||||||
|     </profiles> |  | ||||||
| </project> |  | ||||||
| @ -1,38 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.nearestneighbor.model; |  | ||||||
| 
 |  | ||||||
| import lombok.AllArgsConstructor; |  | ||||||
| import lombok.Builder; |  | ||||||
| import lombok.Data; |  | ||||||
| import lombok.NoArgsConstructor; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| @AllArgsConstructor |  | ||||||
| @NoArgsConstructor |  | ||||||
| @Builder |  | ||||||
| public class Base64NDArrayBody implements Serializable { |  | ||||||
|     private String ndarray; |  | ||||||
|     private int k; |  | ||||||
|     private boolean forceFillK; |  | ||||||
| } |  | ||||||
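This is the request body the server's /knnnew route deserializes. A short sketch of building one directly, mirroring what NearestNeighborsClient.knnNew does internally; the class name and query shape are illustrative.

import org.deeplearning4j.nearestneighbor.model.Base64NDArrayBody;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.serde.base64.Nd4jBase64;

public class RequestBodySketch {
    public static void main(String[] args) throws Exception {
        INDArray query = Nd4j.rand(1, 64);
        Base64NDArrayBody body = Base64NDArrayBody.builder()
                .ndarray(Nd4jBase64.base64String(query))   // Base64-encoded binary serialization of the array
                .k(10)
                .forceFillK(true)   // server then uses VPTreeFillSearch to guarantee k results
                .build();
        System.out.println(body.getNdarray().length());
    }
}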
| @ -1,65 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.nearestneighbor.model; |  | ||||||
| 
 |  | ||||||
| import lombok.AllArgsConstructor; |  | ||||||
| import lombok.Builder; |  | ||||||
| import lombok.Data; |  | ||||||
| import lombok.NoArgsConstructor; |  | ||||||
| import org.nd4j.linalg.dataset.DataSet; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| @AllArgsConstructor |  | ||||||
| @Builder |  | ||||||
| @NoArgsConstructor |  | ||||||
| public class BatchRecord implements Serializable { |  | ||||||
|     private List<CSVRecord> records; |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Add a record to this batch, lazily creating the backing list |  | ||||||
|      * @param record the record to add |  | ||||||
|      */ |  | ||||||
|     public void add(CSVRecord record) { |  | ||||||
|         if (records == null) |  | ||||||
|             records = new ArrayList<>(); |  | ||||||
|         records.add(record); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Return a batch record based on a dataset |  | ||||||
|      * @param dataSet the dataset to get the batch record for |  | ||||||
|      * @return the batch record |  | ||||||
|      */ |  | ||||||
|     public static BatchRecord fromDataSet(DataSet dataSet) { |  | ||||||
|         BatchRecord batchRecord = new BatchRecord(); |  | ||||||
|         for (int i = 0; i < dataSet.numExamples(); i++) { |  | ||||||
|             batchRecord.add(CSVRecord.fromRow(dataSet.get(i))); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return batchRecord; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,85 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.nearestneighbor.model; |  | ||||||
| 
 |  | ||||||
| import lombok.AllArgsConstructor; |  | ||||||
| import lombok.Data; |  | ||||||
| import lombok.NoArgsConstructor; |  | ||||||
| import org.nd4j.linalg.dataset.DataSet; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| @AllArgsConstructor |  | ||||||
| @NoArgsConstructor |  | ||||||
| public class CSVRecord implements Serializable { |  | ||||||
|     private String[] values; |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Instantiate a CSV record from a single-example {@link DataSet}. For classification |  | ||||||
|      * (one-hot labels), the argmax label index is appended to the end of the record; |  | ||||||
|      * for regression, all label values are appended. |  | ||||||
|      * @param row the dataset (features and labels) for a single example |  | ||||||
|      * @return the record derived from this {@link DataSet} |  | ||||||
|      */ |  | ||||||
|     public static CSVRecord fromRow(DataSet row) { |  | ||||||
|         if (!row.getFeatures().isVector() && !row.getFeatures().isScalar()) |  | ||||||
|             throw new IllegalArgumentException("Passed in dataset must represent a scalar or vector"); |  | ||||||
|         if (!row.getLabels().isVector() && !row.getLabels().isScalar()) |  | ||||||
|             throw new IllegalArgumentException("Passed in dataset labels must be a scalar or vector"); |  | ||||||
|         //Classification case: labels are assumed to be one-hot, so they sum to exactly 1.0 |  | ||||||
|         CSVRecord record; |  | ||||||
|         int idx = 0; |  | ||||||
|         if (row.getLabels().sumNumber().doubleValue() == 1.0) { |  | ||||||
|             String[] values = new String[row.getFeatures().columns() + 1]; |  | ||||||
|             for (int i = 0; i < row.getFeatures().length(); i++) { |  | ||||||
|                 values[idx++] = String.valueOf(row.getFeatures().getDouble(i)); |  | ||||||
|             } |  | ||||||
|             int maxIdx = 0; |  | ||||||
|             for (int i = 0; i < row.getLabels().length(); i++) { |  | ||||||
|                 if (row.getLabels().getDouble(maxIdx) < row.getLabels().getDouble(i)) { |  | ||||||
|                     maxIdx = i; |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             values[idx++] = String.valueOf(maxIdx); |  | ||||||
|             record = new CSVRecord(values); |  | ||||||
|         } |  | ||||||
|         //regression (any number of values) |  | ||||||
|         else { |  | ||||||
|             String[] values = new String[row.getFeatures().columns() + row.getLabels().columns()]; |  | ||||||
|             for (int i = 0; i < row.getFeatures().length(); i++) { |  | ||||||
|                 values[idx++] = String.valueOf(row.getFeatures().getDouble(i)); |  | ||||||
|             } |  | ||||||
|             for (int i = 0; i < row.getLabels().length(); i++) { |  | ||||||
|                 values[idx++] = String.valueOf(row.getLabels().getDouble(i)); |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|             record = new CSVRecord(values); |  | ||||||
| 
 |  | ||||||
|         } |  | ||||||
|         return record; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
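A small sketch of the classification path (the class name and values are illustrative): with one-hot labels, the feature values are copied over and the argmax label index is appended as the final column.

import org.deeplearning4j.nearestneighbor.model.CSVRecord;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

public class CsvRecordSketch {
    public static void main(String[] args) {
        INDArray features = Nd4j.create(new double[][] {{0.5, 1.5, 2.5}});
        INDArray labels = Nd4j.create(new double[][] {{0.0, 1.0, 0.0}});   // one-hot: class 1
        CSVRecord record = CSVRecord.fromRow(new DataSet(features, labels));
        // Prints: [0.5, 1.5, 2.5, 1]
        System.out.println(java.util.Arrays.toString(record.getValues()));
    }
}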
| @ -1,32 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.nearestneighbor.model; |  | ||||||
| 
 |  | ||||||
| import lombok.Data; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| public class NearestNeighborRequest implements Serializable { |  | ||||||
|     private int k; |  | ||||||
|     private int inputIndex; |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,37 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.nearestneighbor.model; |  | ||||||
| 
 |  | ||||||
| import lombok.AllArgsConstructor; |  | ||||||
| import lombok.Data; |  | ||||||
| import lombok.NoArgsConstructor; |  | ||||||
| @Data |  | ||||||
| @AllArgsConstructor |  | ||||||
| @NoArgsConstructor |  | ||||||
| public class NearestNeighborsResult { |  | ||||||
|     public NearestNeighborsResult(int index, double distance) { |  | ||||||
|         this(index, distance, null); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private int index; |  | ||||||
|     private double distance; |  | ||||||
|     private String label; |  | ||||||
| } |  | ||||||
| @ -1,38 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.nearestneighbor.model; |  | ||||||
| 
 |  | ||||||
| import lombok.AllArgsConstructor; |  | ||||||
| import lombok.Builder; |  | ||||||
| import lombok.Data; |  | ||||||
| import lombok.NoArgsConstructor; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| @Builder |  | ||||||
| @NoArgsConstructor |  | ||||||
| @AllArgsConstructor |  | ||||||
| public class NearestNeighborsResults implements Serializable { |  | ||||||
|     private List<NearestNeighborsResult> results; |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,103 +0,0 @@ | |||||||
| <?xml version="1.0" encoding="UTF-8"?> |  | ||||||
| <!-- |  | ||||||
|   ~ /* ****************************************************************************** |  | ||||||
|   ~  * |  | ||||||
|   ~  * |  | ||||||
|   ~  * This program and the accompanying materials are made available under the |  | ||||||
|   ~  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|   ~  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|   ~  * |  | ||||||
|   ~  *  See the NOTICE file distributed with this work for additional |  | ||||||
|   ~  *  information regarding copyright ownership. |  | ||||||
|   ~  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|   ~  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|   ~  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|   ~  * License for the specific language governing permissions and limitations |  | ||||||
|   ~  * under the License. |  | ||||||
|   ~  * |  | ||||||
|   ~  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|   ~  ******************************************************************************/ |  | ||||||
|   --> |  | ||||||
| 
 |  | ||||||
| <project xmlns="http://maven.apache.org/POM/4.0.0" |  | ||||||
|     xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" |  | ||||||
|     xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> |  | ||||||
| 
 |  | ||||||
|     <modelVersion>4.0.0</modelVersion> |  | ||||||
| 
 |  | ||||||
|     <parent> |  | ||||||
|         <groupId>org.deeplearning4j</groupId> |  | ||||||
|         <artifactId>deeplearning4j-nearestneighbors-parent</artifactId> |  | ||||||
|         <version>1.0.0-SNAPSHOT</version> |  | ||||||
|     </parent> |  | ||||||
| 
 |  | ||||||
|     <artifactId>nearestneighbor-core</artifactId> |  | ||||||
|     <packaging>jar</packaging> |  | ||||||
| 
 |  | ||||||
|     <name>nearestneighbor-core</name> |  | ||||||
| 
 |  | ||||||
|     <dependencies> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.nd4j</groupId> |  | ||||||
|             <artifactId>nd4j-api</artifactId> |  | ||||||
|             <version>${nd4j.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>junit</groupId> |  | ||||||
|             <artifactId>junit</artifactId> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>ch.qos.logback</groupId> |  | ||||||
|             <artifactId>logback-classic</artifactId> |  | ||||||
|             <scope>test</scope> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.deeplearning4j</groupId> |  | ||||||
|             <artifactId>deeplearning4j-nn</artifactId> |  | ||||||
|             <version>${project.version}</version> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.deeplearning4j</groupId> |  | ||||||
|             <artifactId>deeplearning4j-datasets</artifactId> |  | ||||||
|             <version>${project.version}</version> |  | ||||||
|             <scope>test</scope> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>joda-time</groupId> |  | ||||||
|             <artifactId>joda-time</artifactId> |  | ||||||
|             <version>2.10.3</version> |  | ||||||
|             <scope>test</scope> |  | ||||||
|         </dependency> |  | ||||||
|         <dependency> |  | ||||||
|             <groupId>org.deeplearning4j</groupId> |  | ||||||
|             <artifactId>deeplearning4j-common-tests</artifactId> |  | ||||||
|             <version>${project.version}</version> |  | ||||||
|             <scope>test</scope> |  | ||||||
|         </dependency> |  | ||||||
|     </dependencies> |  | ||||||
| 
 |  | ||||||
|     <profiles> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-native</id> |  | ||||||
|             <dependencies> |  | ||||||
|                 <dependency> |  | ||||||
|                     <groupId>org.nd4j</groupId> |  | ||||||
|                     <artifactId>nd4j-native</artifactId> |  | ||||||
|                     <version>${project.version}</version> |  | ||||||
|                     <scope>test</scope> |  | ||||||
|                 </dependency> |  | ||||||
|             </dependencies> |  | ||||||
|         </profile> |  | ||||||
|         <profile> |  | ||||||
|             <id>test-nd4j-cuda-11.0</id> |  | ||||||
|             <dependencies> |  | ||||||
|                 <dependency> |  | ||||||
|                     <groupId>org.nd4j</groupId> |  | ||||||
|                     <artifactId>nd4j-cuda-11.0</artifactId> |  | ||||||
|                     <version>${project.version}</version> |  | ||||||
|                     <scope>test</scope> |  | ||||||
|                 </dependency> |  | ||||||
|             </dependencies> |  | ||||||
|         </profile> |  | ||||||
|     </profiles> |  | ||||||
| </project> |  | ||||||
| @ -1,218 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.algorithm; |  | ||||||
| 
 |  | ||||||
| import lombok.AccessLevel; |  | ||||||
| import lombok.NoArgsConstructor; |  | ||||||
| import lombok.extern.slf4j.Slf4j; |  | ||||||
| import lombok.val; |  | ||||||
| import org.apache.commons.lang3.ArrayUtils; |  | ||||||
| import org.deeplearning4j.clustering.cluster.Cluster; |  | ||||||
| import org.deeplearning4j.clustering.cluster.ClusterSet; |  | ||||||
| import org.deeplearning4j.clustering.cluster.ClusterUtils; |  | ||||||
| import org.deeplearning4j.clustering.cluster.Point; |  | ||||||
| import org.deeplearning4j.clustering.info.ClusterSetInfo; |  | ||||||
| import org.deeplearning4j.clustering.iteration.IterationHistory; |  | ||||||
| import org.deeplearning4j.clustering.iteration.IterationInfo; |  | ||||||
| import org.deeplearning4j.clustering.strategy.ClusteringStrategy; |  | ||||||
| import org.deeplearning4j.clustering.strategy.ClusteringStrategyType; |  | ||||||
| import org.deeplearning4j.clustering.strategy.OptimisationStrategy; |  | ||||||
| import org.deeplearning4j.clustering.util.MultiThreadUtils; |  | ||||||
| import org.nd4j.common.base.Preconditions; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.List; |  | ||||||
| import java.util.concurrent.ExecutorService; |  | ||||||
| 
 |  | ||||||
| @Slf4j |  | ||||||
| @NoArgsConstructor(access = AccessLevel.PROTECTED) |  | ||||||
| public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializable { |  | ||||||
| 
 |  | ||||||
|     private static final long serialVersionUID = 338231277453149972L; |  | ||||||
| 
 |  | ||||||
|     private ClusteringStrategy clusteringStrategy; |  | ||||||
|     private IterationHistory iterationHistory; |  | ||||||
|     private int currentIteration = 0; |  | ||||||
|     private ClusterSet clusterSet; |  | ||||||
|     private List<Point> initialPoints; |  | ||||||
|     private transient ExecutorService exec; |  | ||||||
|     private boolean useKmeansPlusPlus; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     protected BaseClusteringAlgorithm(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) { |  | ||||||
|         this.clusteringStrategy = clusteringStrategy; |  | ||||||
|         this.exec = MultiThreadUtils.newExecutorService(); |  | ||||||
|         this.useKmeansPlusPlus = useKmeansPlusPlus; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Set up a clustering algorithm for the given strategy |  | ||||||
|      * @param clusteringStrategy the strategy (cluster count, distance function, termination condition) to apply |  | ||||||
|      * @param useKmeansPlusPlus whether to use k-means++ style initialization of the cluster centers |  | ||||||
|      * @return the configured clustering algorithm |  | ||||||
|      */ |  | ||||||
|     public static BaseClusteringAlgorithm setup(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) { |  | ||||||
|         return new BaseClusteringAlgorithm(clusteringStrategy, useKmeansPlusPlus); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Cluster the given points, iterating until the strategy's termination condition is satisfied |  | ||||||
|      * @param points the points to cluster |  | ||||||
|      * @return the resulting cluster set |  | ||||||
|      */ |  | ||||||
|     public ClusterSet applyTo(List<Point> points) { |  | ||||||
|         resetState(points); |  | ||||||
|         initClusters(useKmeansPlusPlus); |  | ||||||
|         iterations(); |  | ||||||
|         return clusterSet; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
    private void resetState(List<Point> points) {
        this.iterationHistory = new IterationHistory();
        this.currentIteration = 0;
        this.clusterSet = null;
        this.initialPoints = points;
    }

    /**
     * Run clustering iterations until a termination condition is hit.
     * This is done by first classifying all points, then updating the
     * cluster centers based on those classified points.
     */
    private void iterations() {
        int iterationCount = 0;
        while ((clusteringStrategy.getTerminationCondition() != null
                        && !clusteringStrategy.getTerminationCondition().isSatisfied(iterationHistory))
                        || iterationHistory.getMostRecentIterationInfo().isStrategyApplied()) {
            currentIteration++;
            removePoints();
            classifyPoints();
            applyClusteringStrategy();
            log.trace("Completed clustering iteration {}", ++iterationCount);
        }
    }

    protected void classifyPoints() {
        //Classify points. This also adds each point to the ClusterSet
        ClusterSetInfo clusterSetInfo = ClusterUtils.classifyPoints(clusterSet, initialPoints, exec);
        //Update the cluster centers, based on the points within each cluster
        ClusterUtils.refreshClustersCenters(clusterSet, clusterSetInfo, exec);
        iterationHistory.getIterationsInfos().put(currentIteration,
                        new IterationInfo(currentIteration, clusterSetInfo));
    }

    /**
     * Initialize the cluster centers at random
     */
    protected void initClusters(boolean kMeansPlusPlus) {
        log.info("Generating initial clusters");
        List<Point> points = new ArrayList<>(initialPoints);

        //Initialize the ClusterSet with a single cluster center (based on the position of one of the points, chosen randomly)
        val random = Nd4j.getRandom();
        Distance distanceFn = clusteringStrategy.getDistanceFunction();
        int initialClusterCount = clusteringStrategy.getInitialClusterCount();
        clusterSet = new ClusterSet(distanceFn,
                        clusteringStrategy.inverseDistanceCalculation(), new long[] {initialClusterCount, points.get(0).getArray().length()});
        clusterSet.addNewClusterWithCenter(points.remove(random.nextInt(points.size())));

        //dxs: distances between each point and the nearest cluster to that point
        INDArray dxs = Nd4j.create(points.size());
        dxs.addi(clusteringStrategy.inverseDistanceCalculation() ? -Double.MAX_VALUE : Double.MAX_VALUE);

        //Generate the initial cluster centers by randomly selecting a threshold between 0 and the max distance;
        //thus, we are more likely to select (as a new cluster center) a point that is far from an existing cluster
        while (clusterSet.getClusterCount() < initialClusterCount && !points.isEmpty()) {
            dxs = ClusterUtils.computeSquareDistancesFromNearestCluster(clusterSet, points, dxs, exec);
            double summed = Nd4j.sum(dxs).getDouble(0);
            double r = kMeansPlusPlus ? random.nextDouble() * summed
                                      : random.nextFloat() * dxs.maxNumber().doubleValue();

            for (int i = 0; i < dxs.length(); i++) {
                double distance = dxs.getDouble(i);
                Preconditions.checkState(distance >= 0, "Encountered negative distance: distance function is not valid? Distance " +
                        "function must return values >= 0, got distance %s for function %s", distance, distanceFn);
                if (distance >= r) {
                    clusterSet.addNewClusterWithCenter(points.remove(i));
                    dxs = Nd4j.create(ArrayUtils.remove(dxs.data().asDouble(), i));
                    break;
                }
            }
        }

        ClusterSetInfo initialClusterSetInfo = ClusterUtils.computeClusterSetInfo(clusterSet);
        iterationHistory.getIterationsInfos().put(currentIteration,
                        new IterationInfo(currentIteration, initialClusterSetInfo));
    }

    protected void applyClusteringStrategy() {
        if (!isStrategyApplicableNow())
            return;

        ClusterSetInfo clusterSetInfo = iterationHistory.getMostRecentClusterSetInfo();
        if (!clusteringStrategy.isAllowEmptyClusters()) {
            int removedCount = removeEmptyClusters(clusterSetInfo);
            if (removedCount > 0) {
                iterationHistory.getMostRecentIterationInfo().setStrategyApplied(true);

                if (clusteringStrategy.isStrategyOfType(ClusteringStrategyType.FIXED_CLUSTER_COUNT)
                                && clusterSet.getClusterCount() < clusteringStrategy.getInitialClusterCount()) {
                    int splitCount = ClusterUtils.splitMostSpreadOutClusters(clusterSet, clusterSetInfo,
                                    clusteringStrategy.getInitialClusterCount() - clusterSet.getClusterCount(), exec);
                    if (splitCount > 0)
                        iterationHistory.getMostRecentIterationInfo().setStrategyApplied(true);
                }
            }
        }
        if (clusteringStrategy.isStrategyOfType(ClusteringStrategyType.OPTIMIZATION))
            optimize();
    }

    protected void optimize() {
        ClusterSetInfo clusterSetInfo = iterationHistory.getMostRecentClusterSetInfo();
        OptimisationStrategy optimization = (OptimisationStrategy) clusteringStrategy;
        boolean applied = ClusterUtils.applyOptimization(optimization, clusterSet, clusterSetInfo, exec);
        iterationHistory.getMostRecentIterationInfo().setStrategyApplied(applied);
    }

    private boolean isStrategyApplicableNow() {
        return clusteringStrategy.isOptimizationDefined() && iterationHistory.getIterationCount() != 0
                        && clusteringStrategy.isOptimizationApplicableNow(iterationHistory);
    }

    protected int removeEmptyClusters(ClusterSetInfo clusterSetInfo) {
        List<Cluster> removedClusters = clusterSet.removeEmptyClusters();
        clusterSetInfo.removeClusterInfos(removedClusters);
        return removedClusters.size();
    }

    protected void removePoints() {
        clusterSet.removePoints();
    }

}
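
For context, a minimal usage sketch of the algorithm removed above. The KMeansClustering entry point and the Point.toPoints helper are assumed from elsewhere in this module and are not shown in this diff; treat the exact setup signature as hypothetical.

import java.util.List;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.Point;
import org.deeplearning4j.clustering.kmeans.KMeansClustering;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ClusteringSketch {
    public static void main(String[] args) {
        INDArray data = Nd4j.rand(500, 10);        // 500 random 10-dimensional points
        List<Point> points = Point.toPoints(data); // wrap each row as a Point (assumed helper)

        // 5 clusters, at most 30 iterations, Euclidean distance, k-means++ style seeding
        // (hypothetical setup signature)
        KMeansClustering kmeans = KMeansClustering.setup(5, 30, Distance.EUCLIDEAN, true);
        ClusterSet clusterSet = kmeans.applyTo(points);
        System.out.println("Clusters found: " + clusterSet.getClusterCount());
    }
}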
@ -1,38 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.clustering.algorithm;

import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.Point;

import java.util.List;

public interface ClusteringAlgorithm {

    /**
     * Apply the clustering algorithm to a set of points.
     *
     * @param points the points to cluster
     * @return the resulting set of clusters
     */
    ClusterSet applyTo(List<Point> points);

}
@ -1,41 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.clustering.algorithm;

public enum Distance {
    EUCLIDEAN("euclidean"),
    COSINE_DISTANCE("cosinedistance"),
    COSINE_SIMILARITY("cosinesimilarity"),
    MANHATTAN("manhattan"),
    DOT("dot"),
    JACCARD("jaccard"),
    HAMMING("hamming");

    private final String functionName;

    Distance(String name) {
        functionName = name;
    }

    @Override
    public String toString() {
        return functionName;
    }
}
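
The enum above is consumed by ClusterUtils.createDistanceFunctionOp, which is referenced throughout this diff but defined outside the shown range. A plausible sketch of such a mapping, using only the standard Nd4j reduce3 ops; this is illustrative, not the actual implementation:

import org.deeplearning4j.clustering.algorithm.Distance;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.reduce3.*;

final class DistanceOps {
    // Map each Distance value onto the matching Nd4j reduce3 operation
    static ReduceOp create(Distance distance, INDArray x, INDArray y) {
        switch (distance) {
            case EUCLIDEAN:         return new EuclideanDistance(x, y);
            case COSINE_DISTANCE:   return new CosineDistance(x, y);
            case COSINE_SIMILARITY: return new CosineSimilarity(x, y);
            case MANHATTAN:         return new ManhattanDistance(x, y);
            case DOT:               return new Dot(x, y);
            case JACCARD:           return new JaccardDistance(x, y);
            case HAMMING:           return new HammingDistance(x, y);
            default: throw new IllegalArgumentException("Unsupported distance: " + distance);
        }
    }
}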
@ -1,105 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.clustering.cluster;

import org.deeplearning4j.clustering.algorithm.Distance;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;

public class CentersHolder {
    private INDArray centers;
    private long index = 0;

    protected transient ReduceOp op;
    protected ArgMin imin;
    protected transient INDArray distances;
    protected transient INDArray argMin;

    private long rows, cols;

    public CentersHolder(long rows, long cols) {
        this.rows = rows;
        this.cols = cols;
    }

    public INDArray getCenters() {
        return this.centers;
    }

    public synchronized void addCenter(INDArray pointView) {
        // Lazily allocate the centers matrix on first use, matching the point's data type
        if (centers == null)
            this.centers = Nd4j.create(pointView.dataType(), new long[] {rows, cols});

        centers.putRow(index++, pointView);
    }

    public synchronized Pair<Double, Long> getCenterByMinDistance(Point point, Distance distanceFunction) {
        // Lazily allocate the scratch buffers and ops on first use; they are reused across calls
        if (distances == null)
            distances = Nd4j.create(centers.dataType(), centers.rows());

        if (argMin == null)
            argMin = Nd4j.createUninitialized(DataType.LONG, new long[0]);

        if (op == null) {
            op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1);
            imin = new ArgMin(distances, argMin);
            op.setZ(distances);
        }

        op.setY(point.getArray());

        // distances[i] = distance(centers[i], point); argMin = index of the smallest distance
        Nd4j.getExecutioner().exec(op);
        Nd4j.getExecutioner().exec(imin);

        Pair<Double, Long> result = new Pair<>();
        result.setFirst(distances.getDouble(argMin.getLong(0)));
        result.setSecond(argMin.getLong(0));
        return result;
    }

    public synchronized INDArray getMinDistances(Point point, Distance distanceFunction) {
        if (distances == null)
            distances = Nd4j.create(centers.dataType(), centers.rows());

        if (argMin == null)
            argMin = Nd4j.createUninitialized(DataType.LONG, new long[0]);

        if (op == null) {
            op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1);
            imin = new ArgMin(distances, argMin);
            op.setZ(distances);
        }

        op.setY(point.getArray());

        Nd4j.getExecutioner().exec(op);
        Nd4j.getExecutioner().exec(imin);

        return distances;
    }

}
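
A short sketch of how CentersHolder is typically driven; the values here are illustrative only:

import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.cluster.CentersHolder;
import org.deeplearning4j.clustering.cluster.Point;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.factory.Nd4j;

public class CentersHolderSketch {
    public static void main(String[] args) {
        // Three 2-D centers
        CentersHolder holder = new CentersHolder(3, 2);
        holder.addCenter(Nd4j.createFromArray(0.0, 0.0));
        holder.addCenter(Nd4j.createFromArray(5.0, 5.0));
        holder.addCenter(Nd4j.createFromArray(10.0, 0.0));

        // Query point (4, 4): nearest center is index 1 at (5, 5), distance sqrt(2)
        Point query = new Point(Nd4j.createFromArray(4.0, 4.0));
        Pair<Double, Long> nearest = holder.getCenterByMinDistance(query, Distance.EUCLIDEAN);
        System.out.println(nearest.getSecond() + " @ " + nearest.getFirst());
    }
}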
@ -1,150 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.clustering.cluster;

import lombok.Data;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.nd4j.linalg.factory.Nd4j;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.UUID;

@Data
public class Cluster implements Serializable {

    private String id = UUID.randomUUID().toString();
    private String label;

    private Point center;
    private List<Point> points = Collections.synchronizedList(new ArrayList<Point>());
    private boolean inverse = false;
    private Distance distanceFunction;

    public Cluster() {
        super();
    }

    /**
     * Create a cluster centered on the given point.
     *
     * @param center the cluster center
     * @param distanceFunction the distance function used to compare points to the center
     */
    public Cluster(Point center, Distance distanceFunction) {
        this(center, false, distanceFunction);
    }

    /**
     * Create a cluster centered on the given point.
     *
     * @param center the cluster center
     * @param inverse whether the distance function is inverse (higher values mean closer, as with cosine similarity)
     * @param distanceFunction the distance function used to compare points to the center
     */
    public Cluster(Point center, boolean inverse, Distance distanceFunction) {
        this.distanceFunction = distanceFunction;
        this.inverse = inverse;
        setCenter(center);
    }

    /**
     * Get the distance between the given point and the cluster center.
     *
     * @param point the point to get the distance for
     * @return the distance (or similarity, if the function is inverse) from the point to the center
     */
    public double getDistanceToCenter(Point point) {
        return Nd4j.getExecutioner().execAndReturn(
                        ClusterUtils.createDistanceFunctionOp(distanceFunction, center.getArray(), point.getArray()))
                        .getFinalResult().doubleValue();
    }

    /**
     * Add a point to the cluster, updating the cluster center.
     *
     * @param point the point to add
     */
    public void addPoint(Point point) {
        addPoint(point, true);
    }

    /**
     * Add a point to the cluster.
     *
     * @param point the point to add
     * @param moveClusterCenter whether to update the cluster center or not
     */
    public void addPoint(Point point, boolean moveClusterCenter) {
        if (moveClusterCenter) {
            // Incremental mean update: newCenter = (center * n +/- point) / (n + 1),
            // where n is the number of points currently in the cluster
            if (isInverse()) {
                center.getArray().muli(points.size()).subi(point.getArray()).divi(points.size() + 1);
            } else {
                center.getArray().muli(points.size()).addi(point.getArray()).divi(points.size() + 1);
            }
        }

        getPoints().add(point);
    }

    /**
     * Clear out the points
     */
    public void removePoints() {
        if (getPoints() != null)
            getPoints().clear();
    }

    /**
     * Whether the cluster is empty or not
     *
     * @return true if the cluster contains no points
     */
    public boolean isEmpty() {
        return points == null || points.isEmpty();
    }

    /**
     * Return the point with the given id
     *
     * @param id the id of the point to look up
     * @return the matching point, or null if not found
     */
    public Point getPoint(String id) {
        for (Point point : points)
            if (id.equals(point.getId()))
                return point;
        return null;
    }

    /**
     * Remove the point with the given id and return it
     *
     * @param id the id of the point to remove
     * @return the removed point, or null if not found
     */
    public Point removePoint(String id) {
        Point removePoint = null;
        for (Point point : points)
            if (id.equals(point.getId()))
                removePoint = point;
        if (removePoint != null)
            points.remove(removePoint);
        return removePoint;
    }

}
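
A small worked check of the incremental center update in addPoint above, using illustrative values; the arithmetic follows directly from the code:

import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.cluster.Cluster;
import org.deeplearning4j.clustering.cluster.Point;
import org.nd4j.linalg.factory.Nd4j;

public class ClusterUpdateSketch {
    public static void main(String[] args) {
        Cluster cluster = new Cluster(new Point(Nd4j.createFromArray(0.0, 0.0)), Distance.EUCLIDEAN);

        // No points yet, so the first point replaces the center: ((0,0) * 0 + (3,3)) / 1 = (3,3)
        cluster.addPoint(new Point(Nd4j.createFromArray(3.0, 3.0)));

        // One point present: ((3,3) * 1 + (9,9)) / 2 = (6,6)
        cluster.addPoint(new Point(Nd4j.createFromArray(9.0, 9.0)));

        System.out.println(cluster.getCenter().getArray()); // [6.0, 6.0]
    }
}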
@ -1,259 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.clustering.cluster;

import lombok.Data;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.common.primitives.Pair;

import java.io.Serializable;
import java.util.*;

@Data
public class ClusterSet implements Serializable {

    private Distance distanceFunction;
    private List<Cluster> clusters;
    private CentersHolder centersHolder;
    private Map<String, String> pointDistribution;
    private boolean inverse;

    public ClusterSet(boolean inverse) {
        this(null, inverse, null);
    }

    public ClusterSet(Distance distanceFunction, boolean inverse, long[] shape) {
        this.distanceFunction = distanceFunction;
        this.inverse = inverse;
        this.clusters = Collections.synchronizedList(new ArrayList<Cluster>());
        this.pointDistribution = Collections.synchronizedMap(new HashMap<String, String>());
        if (shape != null)
            this.centersHolder = new CentersHolder(shape[0], shape[1]);
    }

    public boolean isInverse() {
        return inverse;
    }

    /**
     * Add a new cluster centered on the given point.
     *
     * @param center the center of the new cluster
     * @return the newly created cluster
     */
    public Cluster addNewClusterWithCenter(Point center) {
        Cluster newCluster = new Cluster(center, distanceFunction);
        getClusters().add(newCluster);
        setPointLocation(center, newCluster);
        centersHolder.addCenter(center.getArray());
        return newCluster;
    }

    /**
     * Classify a point, moving the cluster center to include it.
     *
     * @param point the point to classify
     * @return the classification result
     */
    public PointClassification classifyPoint(Point point) {
        return classifyPoint(point, true);
    }

    /**
     * Classify each of the given points, moving the cluster centers to include them.
     *
     * @param points the points to classify
     */
    public void classifyPoints(List<Point> points) {
        classifyPoints(points, true);
    }

    /**
     * Classify each of the given points.
     *
     * @param points the points to classify
     * @param moveClusterCenter whether to update the cluster centers
     */
    public void classifyPoints(List<Point> points, boolean moveClusterCenter) {
        for (Point point : points)
            classifyPoint(point, moveClusterCenter);
    }

    /**
     * Assign a point to its nearest cluster.
     *
     * @param point the point to classify
     * @param moveClusterCenter whether to update the cluster center
     * @return the classification result (cluster, distance, and whether the point changed cluster)
     */
    public PointClassification classifyPoint(Point point, boolean moveClusterCenter) {
        Pair<Cluster, Double> nearestCluster = nearestCluster(point);
        Cluster newCluster = nearestCluster.getKey();
        boolean locationChange = isPointLocationChange(point, newCluster);
        addPointToCluster(point, newCluster, moveClusterCenter);
        return new PointClassification(nearestCluster.getKey(), nearestCluster.getValue(), locationChange);
    }

    private boolean isPointLocationChange(Point point, Cluster newCluster) {
        if (!getPointDistribution().containsKey(point.getId()))
            return true;
        return !getPointDistribution().get(point.getId()).equals(newCluster.getId());
    }

    private void addPointToCluster(Point point, Cluster cluster, boolean moveClusterCenter) {
        cluster.addPoint(point, moveClusterCenter);
        setPointLocation(point, cluster);
    }

    private void setPointLocation(Point point, Cluster cluster) {
        pointDistribution.put(point.getId(), cluster.getId());
    }

    /**
     * Find the cluster whose center is nearest to the given point.
     *
     * @param point the point to look up
     * @return a pair of (nearest cluster, distance from the point to its center)
     */
    public Pair<Cluster, Double> nearestCluster(Point point) {
        Pair<Double, Long> nearestCenterData = centersHolder.getCenterByMinDistance(point, distanceFunction);
        Cluster nearestCluster = getClusters().get(nearestCenterData.getSecond().intValue());
        double minDistance = nearestCenterData.getFirst();
        return Pair.of(nearestCluster, minDistance);
    }

    /**
     * Get the distance between two points under this cluster set's distance function.
     *
     * @param m1 the first point
     * @param m2 the second point
     * @return the distance between the two points
     */
    public double getDistance(Point m1, Point m2) {
        return Nd4j.getExecutioner()
                        .execAndReturn(ClusterUtils.createDistanceFunctionOp(distanceFunction, m1.getArray(), m2.getArray()))
                        .getFinalResult().doubleValue();
    }

    /**
     * Get the id of the center point of the given cluster.
     *
     * @param clusterId the id of the cluster
     * @return the center point's id, or null if the cluster does not exist
     */
    public String getClusterCenterId(String clusterId) {
        Point clusterCenter = getClusterCenter(clusterId);
        return clusterCenter == null ? null : clusterCenter.getId();
    }

    /**
     * Get the center point of the given cluster.
     *
     * @param clusterId the id of the cluster
     * @return the center point, or null if the cluster does not exist
     */
    public Point getClusterCenter(String clusterId) {
        Cluster cluster = getCluster(clusterId);
        return cluster == null ? null : cluster.getCenter();
    }

    /**
     * Look up a cluster by id.
     *
     * @param id the cluster id
     * @return the matching cluster, or null if not found
     */
    public Cluster getCluster(String id) {
        for (int i = 0, j = clusters.size(); i < j; i++)
            if (id.equals(clusters.get(i).getId()))
                return clusters.get(i);
        return null;
    }

    /**
     * Get the number of clusters in the set.
     *
     * @return the cluster count
     */
    public int getClusterCount() {
        return getClusters() == null ? 0 : getClusters().size();
    }

    /**
     * Remove all points from all clusters, keeping the cluster centers.
     */
    public void removePoints() {
        for (Cluster cluster : getClusters())
            cluster.removePoints();
    }

    /**
     * Get the clusters with the most points, in descending order of size.
     *
     * @param count the maximum number of clusters to return
     * @return the most populated clusters
     */
    public List<Cluster> getMostPopulatedClusters(int count) {
        List<Cluster> mostPopulated = new ArrayList<>(clusters);
        Collections.sort(mostPopulated, new Comparator<Cluster>() {
            public int compare(Cluster o1, Cluster o2) {
                return Integer.compare(o2.getPoints().size(), o1.getPoints().size());
            }
        });
        // Guard against requesting more clusters than exist
        return mostPopulated.subList(0, Math.min(count, mostPopulated.size()));
    }

    /**
     * Remove all empty clusters from the set.
     *
     * @return the clusters that were removed
     */
    public List<Cluster> removeEmptyClusters() {
        List<Cluster> emptyClusters = new ArrayList<>();
        for (Cluster cluster : clusters)
            if (cluster.isEmpty())
                emptyClusters.add(cluster);
        clusters.removeAll(emptyClusters);
        return emptyClusters;
    }

}
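
And a corresponding sketch for ClusterSet: build a two-cluster set and classify a point. The values are illustrative only:

import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.cluster.ClusterSet;
import org.deeplearning4j.clustering.cluster.Point;
import org.deeplearning4j.clustering.cluster.PointClassification;
import org.nd4j.linalg.factory.Nd4j;

public class ClusterSetSketch {
    public static void main(String[] args) {
        // Room for two 2-D centers
        ClusterSet clusterSet = new ClusterSet(Distance.EUCLIDEAN, false, new long[] {2, 2});
        clusterSet.addNewClusterWithCenter(new Point(Nd4j.createFromArray(0.0, 0.0)));
        clusterSet.addNewClusterWithCenter(new Point(Nd4j.createFromArray(10.0, 10.0)));

        // (1, 1) is nearest the origin cluster, at distance sqrt(2)
        PointClassification pc = clusterSet.classifyPoint(new Point(Nd4j.createFromArray(1.0, 1.0)));
        System.out.println(pc.getDistanceFromCenter());
    }
}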
@ -1,531 +0,0 @@
/*
 *  ******************************************************************************
 *  *
 *  *
 *  * This program and the accompanying materials are made available under the
 *  * terms of the Apache License, Version 2.0 which is available at
 *  * https://www.apache.org/licenses/LICENSE-2.0.
 *  *
 *  *  See the NOTICE file distributed with this work for additional
 *  *  information regarding copyright ownership.
 *  * Unless required by applicable law or agreed to in writing, software
 *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 *  * License for the specific language governing permissions and limitations
 *  * under the License.
 *  *
 *  * SPDX-License-Identifier: Apache-2.0
 *  *****************************************************************************
 */

package org.deeplearning4j.clustering.cluster;

import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.clustering.algorithm.Distance;
import org.deeplearning4j.clustering.info.ClusterInfo;
import org.deeplearning4j.clustering.info.ClusterSetInfo;
import org.deeplearning4j.clustering.optimisation.ClusteringOptimizationType;
import org.deeplearning4j.clustering.strategy.OptimisationStrategy;
import org.deeplearning4j.clustering.util.MathUtils;
import org.deeplearning4j.clustering.util.MultiThreadUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.ReduceOp;
import org.nd4j.linalg.api.ops.impl.reduce3.*;
import org.nd4j.linalg.factory.Nd4j;

import java.util.*;
import java.util.concurrent.ExecutorService;

@NoArgsConstructor(access = AccessLevel.PRIVATE)
@Slf4j
public class ClusterUtils {

    /** Classify the set of points based on the cluster centers. This also adds each point to the ClusterSet. */
    public static ClusterSetInfo classifyPoints(final ClusterSet clusterSet, List<Point> points,
                    ExecutorService executorService) {
        final ClusterSetInfo clusterSetInfo = ClusterSetInfo.initialize(clusterSet, true);

        // Points are classified sequentially; the executorService is unused since the parallel path was removed
        for (final Point point : points) {
            try {
                PointClassification result = classifyPoint(clusterSet, point);
                if (result.isNewLocation())
                    clusterSetInfo.getPointLocationChange().incrementAndGet();
                clusterSetInfo.getClusterInfo(result.getCluster().getId()).getPointDistancesFromCenter()
                                .put(point.getId(), result.getDistanceFromCenter());
            } catch (Throwable t) {
                log.warn("Error classifying point", t);
            }
        }

        return clusterSetInfo;
    }

    public static PointClassification classifyPoint(ClusterSet clusterSet, Point point) {
        return clusterSet.classifyPoint(point, false);
    }

    public static void refreshClustersCenters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo,
                    ExecutorService executorService) {
        // Centers are refreshed sequentially; the executorService is unused since the parallel path was removed
        int nClusters = clusterSet.getClusterCount();
        for (int i = 0; i < nClusters; i++) {
            final Cluster cluster = clusterSet.getClusters().get(i);
            try {
                final ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
                refreshClusterCenter(cluster, clusterInfo);
                deriveClusterInfoDistanceStatistics(clusterInfo);
            } catch (Throwable t) {
                log.warn("Error refreshing cluster centers", t);
            }
        }
    }

    public static void refreshClusterCenter(Cluster cluster, ClusterInfo clusterInfo) {
        int pointsCount = cluster.getPoints().size();
        if (pointsCount == 0)
            return;
        // Recompute the center as the mean of the cluster's points (negated sum for inverse distance functions)
        Point center = new Point(Nd4j.create(cluster.getPoints().get(0).getArray().length()));
        for (Point point : cluster.getPoints()) {
            INDArray arr = point.getArray();
            if (cluster.isInverse())
                center.getArray().subi(arr);
            else
                center.getArray().addi(arr);
        }
        center.getArray().divi(pointsCount);
        cluster.setCenter(center);
    }

    /**
     * Derive the max, total, average and variance of the point-to-center distances for a cluster.
     *
     * @param info the cluster info to update in place
     */
    public static void deriveClusterInfoDistanceStatistics(ClusterInfo info) {
        int pointCount = info.getPointDistancesFromCenter().size();
        if (pointCount == 0)
            return;

        double[] distances =
                        ArrayUtils.toPrimitive(info.getPointDistancesFromCenter().values().toArray(new Double[] {}));
        // For inverse (similarity) functions, the "farthest" point is the one with the smallest value
        double max = info.isInverse() ? MathUtils.min(distances) : MathUtils.max(distances);
        double total = MathUtils.sum(distances);
        info.setMaxPointDistanceFromCenter(max);
        info.setTotalPointDistanceFromCenter(total);
        info.setAveragePointDistanceFromCenter(total / pointCount);
        info.setPointDistanceFromCenterVariance(MathUtils.variance(distances));
    }

    /**
     * Compute, for every point, the squared distance to its nearest cluster center.
     *
     * @param clusterSet the current cluster set (the most recently added cluster is the candidate)
     * @param points the points to measure
     * @param previousDxs the previous per-point minimum distances
     * @param executorService unused since the parallel path was removed
     * @return the updated per-point minimum distances
     */
    public static INDArray computeSquareDistancesFromNearestCluster(final ClusterSet clusterSet,
                    final List<Point> points, INDArray previousDxs, ExecutorService executorService) {
        final int pointsCount = points.size();
        final INDArray dxs = Nd4j.create(pointsCount);
        final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1);

        for (int i = 0; i < pointsCount; i++) {
            try {
                Point point = points.get(i);
                double dist = clusterSet.isInverse() ? newCluster.getDistanceToCenter(point)
                        : Math.pow(newCluster.getDistanceToCenter(point), 2);
                dxs.putScalar(i, dist);
            } catch (Throwable t) {
                log.warn("Error computing squared distance from nearest cluster", t);
            }
        }

        // Keep whichever is closer: the distance to the new cluster or the previous minimum
        // (for inverse/similarity functions, "closer" means the larger value)
        for (int i = 0; i < pointsCount; i++) {
            double previousMinDistance = previousDxs.getDouble(i);
            if (clusterSet.isInverse()) {
                if (dxs.getDouble(i) < previousMinDistance)
                    dxs.putScalar(i, previousMinDistance);
            } else if (dxs.getDouble(i) > previousMinDistance)
                dxs.putScalar(i, previousMinDistance);
        }

        return dxs;
    }

    public static INDArray computeWeightedProbaDistancesFromNearestCluster(final ClusterSet clusterSet,
                                                                    final List<Point> points, INDArray previousDxs) {
        final int pointsCount = points.size();
        final INDArray dxs = Nd4j.create(pointsCount);
        final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1);

        // dxs[i] holds the running (cumulative) sum of squared distances up to point i
        double sum = 0;
        for (int i = 0; i < pointsCount; i++) {
            Point point = points.get(i);
            double dist = Math.pow(newCluster.getDistanceToCenter(point), 2);
            sum += dist;
            dxs.putScalar(i, sum);
        }

        return dxs;
    }
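
    /*
     * Worked example for the cumulative sums above: with squared distances
     * {1.0, 4.0, 9.0} the returned array is {1.0, 5.0, 14.0}. Drawing r
     * uniformly from [0, 14) and taking the first index whose cumulative
     * value is >= r selects point i with probability d_i^2 / 14, which is
     * the k-means++ seeding rule.
     */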

    /**
     * Compute the cluster set info using a temporary executor.
     *
     * @param clusterSet the cluster set to analyze
     * @return the computed info
     */
    public static ClusterSetInfo computeClusterSetInfo(ClusterSet clusterSet) {
        ExecutorService executor = MultiThreadUtils.newExecutorService();
        ClusterSetInfo info = computeClusterSetInfo(clusterSet, executor);
        executor.shutdownNow();
        return info;
    }

    public static ClusterSetInfo computeClusterSetInfo(final ClusterSet clusterSet, ExecutorService executorService) {
        final ClusterSetInfo info = new ClusterSetInfo(clusterSet.isInverse(), true);
        int clusterCount = clusterSet.getClusterCount();

        // Per-cluster statistics (computed sequentially; the parallel path was removed)
        for (int i = 0; i < clusterCount; i++) {
            final Cluster cluster = clusterSet.getClusters().get(i);
            try {
                info.getClustersInfos().put(cluster.getId(),
                        computeClusterInfos(cluster, clusterSet.getDistanceFunction()));
            } catch (Throwable t) {
                log.warn("Error computing cluster set info", t);
            }
        }

        // Pairwise distances between cluster centers
        for (int i = 0; i < clusterCount; i++) {
            final int clusterIdx = i;
            final Cluster fromCluster = clusterSet.getClusters().get(i);
            try {
                for (int k = clusterIdx + 1, l = clusterSet.getClusterCount(); k < l; k++) {
                    Cluster toCluster = clusterSet.getClusters().get(k);
                    double distance = Nd4j.getExecutioner()
                                    .execAndReturn(ClusterUtils.createDistanceFunctionOp(
                                                    clusterSet.getDistanceFunction(),
                                                    fromCluster.getCenter().getArray(),
                                                    toCluster.getCenter().getArray()))
                                    .getFinalResult().doubleValue();
                    info.getDistancesBetweenClustersCenters().put(fromCluster.getId(), toCluster.getId(),
                                    distance);
                }
            } catch (Throwable t) {
                log.warn("Error computing distances", t);
            }
        }

        return info;
    }

    /**
     * Compute the per-point distance statistics for a single cluster.
     *
     * @param cluster the cluster to analyze
     * @param distanceFunction the distance function to use
     * @return the computed cluster info
     */
    public static ClusterInfo computeClusterInfos(Cluster cluster, Distance distanceFunction) {
        ClusterInfo info = new ClusterInfo(cluster.isInverse(), true);
        for (int i = 0, j = cluster.getPoints().size(); i < j; i++) {
            Point point = cluster.getPoints().get(i);
            //shouldn't need to inverse here; other parts of
            //the code should interpret the "distance" or score here
            double distance = Nd4j.getExecutioner()
                            .execAndReturn(ClusterUtils.createDistanceFunctionOp(distanceFunction,
                                            cluster.getCenter().getArray(), point.getArray()))
                            .getFinalResult().doubleValue();
            info.getPointDistancesFromCenter().put(point.getId(), distance);
            double total = info.getTotalPointDistanceFromCenter() + distance;
            info.setTotalPointDistanceFromCenter(total);
        }

        if (!cluster.getPoints().isEmpty())
            info.setAveragePointDistanceFromCenter(info.getTotalPointDistanceFromCenter() / cluster.getPoints().size());
        return info;
    }

    /**
     * Apply the given optimization strategy by splitting clusters that violate its threshold.
     *
     * @param optimization the optimization strategy
     * @param clusterSet the cluster set to optimize
     * @param clusterSetInfo the current cluster set statistics
     * @param executor the executor used for splitting
     * @return true if any cluster was split
     */
    public static boolean applyOptimization(OptimisationStrategy optimization, ClusterSet clusterSet,
                    ClusterSetInfo clusterSetInfo, ExecutorService executor) {

        if (optimization.isClusteringOptimizationType(
                        ClusteringOptimizationType.MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE)) {
            int splitCount = ClusterUtils.splitClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet,
                            clusterSetInfo, optimization.getClusteringOptimizationValue(), executor);
            return splitCount > 0;
        }

        if (optimization.isClusteringOptimizationType(
                        ClusteringOptimizationType.MINIMIZE_MAXIMUM_POINT_TO_CENTER_DISTANCE)) {
            int splitCount = ClusterUtils.splitClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet,
                            clusterSetInfo, optimization.getClusteringOptimizationValue(), executor);
            return splitCount > 0;
        }

        return false;
    }

    /**
     * Get the clusters with the largest total point-to-center distance.
     *
     * @param clusterSet the cluster set to search
     * @param info the current cluster set statistics
     * @param count the number of clusters to return
     * @return the most spread out clusters
     */
    public static List<Cluster> getMostSpreadOutClusters(final ClusterSet clusterSet, final ClusterSetInfo info,
                    int count) {
        List<Cluster> clusters = new ArrayList<>(clusterSet.getClusters());
        Collections.sort(clusters, new Comparator<Cluster>() {
            public int compare(Cluster o1, Cluster o2) {
                Double o1TotalDistance = info.getClusterInfo(o1.getId()).getTotalPointDistanceFromCenter();
                Double o2TotalDistance = info.getClusterInfo(o2.getId()).getTotalPointDistanceFromCenter();
                int comp = o1TotalDistance.compareTo(o2TotalDistance);
                // Descending by total distance for regular distance functions, ascending for inverse ones
                return !clusterSet.getClusters().get(0).isInverse() ? -comp : comp;
            }
        });

        // Guard against requesting more clusters than exist
        return clusters.subList(0, Math.min(count, clusters.size()));
    }

    /**
     * Get the clusters whose average point-to-center distance exceeds the given threshold.
     *
     * @param clusterSet the cluster set to search
     * @param info the current cluster set statistics
     * @param maximumAverageDistance the average distance threshold
     * @return the clusters exceeding the threshold
     */
    public static List<Cluster> getClustersWhereAverageDistanceFromCenterGreaterThan(final ClusterSet clusterSet,
                    final ClusterSetInfo info, double maximumAverageDistance) {
        List<Cluster> clusters = new ArrayList<>();
        for (Cluster cluster : clusterSet.getClusters()) {
            ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId());
            if (clusterInfo != null) {
                // For inverse (similarity) functions the comparison is flipped: smaller means farther
                if (clusterInfo.isInverse()) {
                    if (clusterInfo.getAveragePointDistanceFromCenter() < maximumAverageDistance)
                        clusters.add(cluster);
                } else {
                    if (clusterInfo.getAveragePointDistanceFromCenter() > maximumAverageDistance)
                        clusters.add(cluster);
                }
            }
        }
        return clusters;
    }

    /**
     * Get the clusters whose maximum point-to-center distance exceeds the given threshold.
     *
     * @param clusterSet the cluster set to search
     * @param info the current cluster set statistics
     * @param maximumDistance the maximum distance threshold
     * @return the clusters exceeding the threshold
     */
    public static List<Cluster> getClustersWhereMaximumDistanceFromCenterGreaterThan(final ClusterSet clusterSet,
                    final ClusterSetInfo info, double maximumDistance) {
        List<Cluster> clusters = new ArrayList<>();
        for (Cluster cluster : clusterSet.getClusters()) {
            ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId());
            if (clusterInfo != null) {
                if (clusterInfo.isInverse() && clusterInfo.getMaxPointDistanceFromCenter() < maximumDistance) {
                    clusters.add(cluster);
                } else if (clusterInfo.getMaxPointDistanceFromCenter() > maximumDistance) {
                    clusters.add(cluster);
                }
            }
        }
        return clusters;
    }

    /**
     * Split the most spread out clusters.
     *
     * @param clusterSet the cluster set to modify
     * @param clusterSetInfo the current cluster set statistics
     * @param count the number of clusters to split
     * @param executorService the executor to use for splitting
     * @return the number of clusters split
     */
    public static int splitMostSpreadOutClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count,
                    ExecutorService executorService) {
        List<Cluster> clustersToSplit = getMostSpreadOutClusters(clusterSet, clusterSetInfo, count);
        splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService);
        return clustersToSplit.size();
    }

    /**
     * Split every cluster whose average point-to-center distance exceeds the threshold.
     *
     * @param clusterSet the cluster set to modify
     * @param clusterSetInfo the current cluster set statistics
     * @param maxWithinClusterDistance the average distance threshold
     * @param executorService the executor to use for splitting
     * @return the number of clusters split
     */
    public static int splitClustersWhereAverageDistanceFromCenterGreaterThan(ClusterSet clusterSet,
                    ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) {
        List<Cluster> clustersToSplit = getClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo,
                        maxWithinClusterDistance);
        splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService);
        return clustersToSplit.size();
    }

|     /** |  | ||||||
|      * Splits every cluster whose maximum point-to-center distance exceeds the given limit. |  | ||||||
|      * @param clusterSet the cluster set to modify |  | ||||||
|      * @param clusterSetInfo precomputed statistics for the cluster set |  | ||||||
|      * @param maxWithinClusterDistance maximum tolerated point-to-center distance within a cluster |  | ||||||
|      * @param executorService executor used to split clusters in parallel |  | ||||||
|      * @return the number of clusters split |  | ||||||
|      */ |  | ||||||
|     public static int splitClustersWhereMaximumDistanceFromCenterGreaterThan(ClusterSet clusterSet, |  | ||||||
|                     ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) { |  | ||||||
|         List<Cluster> clustersToSplit = getClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo, |  | ||||||
|                         maxWithinClusterDistance); |  | ||||||
|         splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService); |  | ||||||
|         return clustersToSplit.size(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Splits the 'count' clusters holding the most points. |  | ||||||
|      * @param clusterSet the cluster set to modify |  | ||||||
|      * @param clusterSetInfo precomputed statistics for the cluster set |  | ||||||
|      * @param count how many clusters to split |  | ||||||
|      * @param executorService executor used to split clusters in parallel |  | ||||||
|      */ |  | ||||||
|     public static void splitMostPopulatedClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count, |  | ||||||
|                     ExecutorService executorService) { |  | ||||||
|         List<Cluster> clustersToSplit = clusterSet.getMostPopulatedClusters(count); |  | ||||||
|         splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Splits each given cluster by promoting one of its points lying beyond 'maxDistance' |  | ||||||
|      * to the center of a new cluster. |  | ||||||
|      * @param clusterSet the cluster set to modify |  | ||||||
|      * @param clusterSetInfo precomputed statistics for the cluster set |  | ||||||
|      * @param clusters the clusters to split |  | ||||||
|      * @param maxDistance distance beyond which a point is considered an outlier |  | ||||||
|      * @param executorService executor used to split clusters in parallel |  | ||||||
|      */ |  | ||||||
|     public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo, |  | ||||||
|                     List<Cluster> clusters, final double maxDistance, ExecutorService executorService) { |  | ||||||
|         final Random random = new Random(); |  | ||||||
|         List<Runnable> tasks = new ArrayList<>(); |  | ||||||
|         for (final Cluster cluster : clusters) { |  | ||||||
|             tasks.add(new Runnable() { |  | ||||||
|                 public void run() { |  | ||||||
|                     try { |  | ||||||
|                         ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId()); |  | ||||||
|                         List<String> fartherPoints = clusterInfo.getPointsFartherFromCenterThan(maxDistance); |  | ||||||
|                         if (fartherPoints.isEmpty()) |  | ||||||
|                             return; //no outliers in this cluster, nothing to split |  | ||||||
|                         //promote one of the (up to) three farthest points to a new cluster center |  | ||||||
|                         int rank = Math.min(fartherPoints.size(), 3); |  | ||||||
|                         String pointId = fartherPoints.get(random.nextInt(rank)); |  | ||||||
|                         Point point = cluster.removePoint(pointId); |  | ||||||
|                         clusterSet.addNewClusterWithCenter(point); |  | ||||||
|                     } catch (Throwable t) { |  | ||||||
|                         log.warn("Error splitting clusters", t); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             }); |  | ||||||
|         } |  | ||||||
|         MultiThreadUtils.parallelTasks(tasks, executorService); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Splits each given cluster by promoting one of its points, chosen at random, |  | ||||||
|      * to the center of a new cluster. |  | ||||||
|      * @param clusterSet the cluster set to modify |  | ||||||
|      * @param clusterSetInfo precomputed statistics for the cluster set |  | ||||||
|      * @param clusters the clusters to split |  | ||||||
|      * @param executorService executor used to split clusters in parallel |  | ||||||
|      */ |  | ||||||
|     public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo, |  | ||||||
|                     List<Cluster> clusters, ExecutorService executorService) { |  | ||||||
|         final Random random = new Random(); |  | ||||||
|         List<Runnable> tasks = new ArrayList<>(); |  | ||||||
|         for (final Cluster cluster : clusters) { |  | ||||||
|             tasks.add(new Runnable() { |  | ||||||
|                 public void run() { |  | ||||||
|                     try { |  | ||||||
|                         Point point = cluster.getPoints().remove(random.nextInt(cluster.getPoints().size())); |  | ||||||
|                         clusterSet.addNewClusterWithCenter(point); |  | ||||||
|                     } catch (Throwable t) { |  | ||||||
|                         log.warn("Error Splitting clusters (2)", t); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             }); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         MultiThreadUtils.parallelTasks(tasks, executorService); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public static ReduceOp createDistanceFunctionOp(Distance distanceFunction, INDArray x, INDArray y, int... dimensions) { |  | ||||||
|         val op = createDistanceFunctionOp(distanceFunction, x, y); |  | ||||||
|         op.setDimensions(dimensions); |  | ||||||
|         return op; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public static ReduceOp createDistanceFunctionOp(Distance distanceFunction, INDArray x, INDArray y) { |  | ||||||
|         switch (distanceFunction) { |  | ||||||
|             case COSINE_DISTANCE: |  | ||||||
|                 return new CosineDistance(x,y); |  | ||||||
|             case COSINE_SIMILARITY: |  | ||||||
|                 return new CosineSimilarity(x,y); |  | ||||||
|             case DOT: |  | ||||||
|                 return new Dot(x,y); |  | ||||||
|             case EUCLIDEAN: |  | ||||||
|                 return new EuclideanDistance(x,y); |  | ||||||
|             case JACCARD: |  | ||||||
|                 return new JaccardDistance(x,y); |  | ||||||
|             case MANHATTAN: |  | ||||||
|                 return new ManhattanDistance(x,y); |  | ||||||
|             default: |  | ||||||
|                 throw new IllegalStateException("Unknown distance function: " + distanceFunction); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
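A minimal usage sketch of the factory above, with two illustrative vectors; the call is written unqualified, as from inside the same class, and relies only on the ND4J types this file already uses:

    INDArray a = Nd4j.createFromArray(new float[] {1f, 0f});
    INDArray b = Nd4j.createFromArray(new float[] {0f, 1f});
    ReduceOp op = createDistanceFunctionOp(Distance.EUCLIDEAN, a, b);
    double d = Nd4j.getExecutioner().execAndReturn(op).getFinalResult().doubleValue(); // sqrt(2)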
| @ -1,107 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.cluster; |  | ||||||
| 
 |  | ||||||
| import lombok.AccessLevel; |  | ||||||
| import lombok.Data; |  | ||||||
| import lombok.NoArgsConstructor; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.List; |  | ||||||
| import java.util.UUID; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * A point in the feature space: an id, an optional label and the backing INDArray vector. |  | ||||||
|  */ |  | ||||||
| @Data |  | ||||||
| @NoArgsConstructor(access = AccessLevel.PROTECTED) |  | ||||||
| public class Point implements Serializable { |  | ||||||
| 
 |  | ||||||
|     private static final long serialVersionUID = -6658028541426027226L; |  | ||||||
| 
 |  | ||||||
|     private String id = UUID.randomUUID().toString(); |  | ||||||
|     private String label; |  | ||||||
|     private INDArray array; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Creates a point with a random UUID as its id. |  | ||||||
|      * @param array the vector backing this point |  | ||||||
|      */ |  | ||||||
|     public Point(INDArray array) { |  | ||||||
|         super(); |  | ||||||
|         this.array = array; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Creates a point with the given id. |  | ||||||
|      * @param id unique identifier of the point |  | ||||||
|      * @param array the vector backing this point |  | ||||||
|      */ |  | ||||||
|     public Point(String id, INDArray array) { |  | ||||||
|         super(); |  | ||||||
|         this.id = id; |  | ||||||
|         this.array = array; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public Point(String id, String label, double[] data) { |  | ||||||
|         this(id, label, Nd4j.create(data)); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public Point(String id, String label, INDArray array) { |  | ||||||
|         super(); |  | ||||||
|         this.id = id; |  | ||||||
|         this.label = label; |  | ||||||
|         this.array = array; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Wraps each row of the given matrix in a Point. |  | ||||||
|      * @param matrix a matrix with one point per row |  | ||||||
|      * @return one Point per row of the matrix |  | ||||||
|      */ |  | ||||||
|     public static List<Point> toPoints(INDArray matrix) { |  | ||||||
|         List<Point> arr = new ArrayList<>(matrix.rows()); |  | ||||||
|         for (int i = 0; i < matrix.rows(); i++) { |  | ||||||
|             arr.add(new Point(matrix.getRow(i))); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return arr; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Wraps each vector of the given list in a Point. |  | ||||||
|      * @param vectors the vectors to wrap |  | ||||||
|      * @return one Point per input vector |  | ||||||
|      */ |  | ||||||
|     public static List<Point> toPoints(List<INDArray> vectors) { |  | ||||||
|         List<Point> points = new ArrayList<>(); |  | ||||||
|         for (INDArray vector : vectors) |  | ||||||
|             points.add(new Point(vector)); |  | ||||||
|         return points; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
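A short sketch of how the factories above are typically used; the 10x3 matrix is hypothetical:

    INDArray data = Nd4j.rand(10, 3);          // ten points in three dimensions
    List<Point> points = Point.toPoints(data); // one Point per row, each with a random UUID id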
| @ -1,40 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.cluster; |  | ||||||
| 
 |  | ||||||
| import lombok.AccessLevel; |  | ||||||
| import lombok.AllArgsConstructor; |  | ||||||
| import lombok.Data; |  | ||||||
| import lombok.NoArgsConstructor; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| @NoArgsConstructor(access = AccessLevel.PROTECTED) |  | ||||||
| @AllArgsConstructor |  | ||||||
| public class PointClassification implements Serializable { |  | ||||||
| 
 |  | ||||||
|     private Cluster cluster; //cluster the point was assigned to |  | ||||||
|     private double distanceFromCenter; //distance between the point and that cluster's center |  | ||||||
|     private boolean newLocation; //true if the point moved to a different cluster |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,37 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.condition; |  | ||||||
| 
 |  | ||||||
| import org.deeplearning4j.clustering.iteration.IterationHistory; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * A termination condition for a clustering run, evaluated against the iteration history. |  | ||||||
|  */ |  | ||||||
| public interface ClusteringAlgorithmCondition { |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param iterationHistory the history of the clustering run so far |  | ||||||
|      * @return true if the algorithm should stop iterating |  | ||||||
|      */ |  | ||||||
|     boolean isSatisfied(IterationHistory iterationHistory); |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
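Implementations are small predicates over the run's history. A hypothetical sketch, in the anonymous-class style used elsewhere in this module:

    // Stop once at least ten iterations have completed (the bound is illustrative).
    ClusteringAlgorithmCondition tenIterations = new ClusteringAlgorithmCondition() {
        public boolean isSatisfied(IterationHistory iterationHistory) {
            return iterationHistory != null && iterationHistory.getIterationCount() >= 10;
        }
    };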
| @ -1,69 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.condition; |  | ||||||
| 
 |  | ||||||
| import lombok.AccessLevel; |  | ||||||
| import lombok.AllArgsConstructor; |  | ||||||
| import lombok.NoArgsConstructor; |  | ||||||
| import org.deeplearning4j.clustering.iteration.IterationHistory; |  | ||||||
| import org.nd4j.linalg.indexing.conditions.Condition; |  | ||||||
| import org.nd4j.linalg.indexing.conditions.LessThan; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| 
 |  | ||||||
| @NoArgsConstructor(access = AccessLevel.PROTECTED) |  | ||||||
| @AllArgsConstructor(access = AccessLevel.PROTECTED) |  | ||||||
| public class ConvergenceCondition implements ClusteringAlgorithmCondition, Serializable { |  | ||||||
| 
 |  | ||||||
|     private Condition convergenceCondition; |  | ||||||
|     private double pointsDistributionChangeRate; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Condition satisfied once the fraction of points that changed cluster during the most |  | ||||||
|      * recent iteration drops below the given rate. |  | ||||||
|      * @param pointsDistributionChangeRate maximum tolerated fraction of points changing cluster |  | ||||||
|      * @return the convergence condition |  | ||||||
|      */ |  | ||||||
|     public static ConvergenceCondition distributionVariationRateLessThan(double pointsDistributionChangeRate) { |  | ||||||
|         Condition condition = new LessThan(pointsDistributionChangeRate); |  | ||||||
|         return new ConvergenceCondition(condition, pointsDistributionChangeRate); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param iterationHistory the history of the clustering run so far |  | ||||||
|      * @return true once at least two iterations have run and the per-point location |  | ||||||
|      * change rate satisfies the convergence condition |  | ||||||
|      */ |  | ||||||
|     public boolean isSatisfied(IterationHistory iterationHistory) { |  | ||||||
|         int iterationCount = iterationHistory.getIterationCount(); |  | ||||||
|         if (iterationCount <= 1) |  | ||||||
|             return false; |  | ||||||
| 
 |  | ||||||
|         double variation = iterationHistory.getMostRecentClusterSetInfo().getPointLocationChange().get(); |  | ||||||
|         variation /= iterationHistory.getMostRecentClusterSetInfo().getPointsCount(); |  | ||||||
| 
 |  | ||||||
|         return convergenceCondition.apply(variation); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
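A usage sketch; the 0.01 rate is illustrative and 'history' stands for an IterationHistory maintained by the clustering driver:

    ClusteringAlgorithmCondition converged = ConvergenceCondition.distributionVariationRateLessThan(0.01);
    boolean done = converged.isSatisfied(history); // true once under 1% of points changed cluster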
| @ -1,61 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.condition; |  | ||||||
| 
 |  | ||||||
| import lombok.AccessLevel; |  | ||||||
| import lombok.NoArgsConstructor; |  | ||||||
| import org.deeplearning4j.clustering.iteration.IterationHistory; |  | ||||||
| import org.nd4j.linalg.indexing.conditions.Condition; |  | ||||||
| import org.nd4j.linalg.indexing.conditions.GreaterThanOrEqual; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * Condition satisfied once a fixed number of iterations has been reached. |  | ||||||
|  */ |  | ||||||
| @NoArgsConstructor(access = AccessLevel.PROTECTED) |  | ||||||
| public class FixedIterationCountCondition implements ClusteringAlgorithmCondition, Serializable { |  | ||||||
| 
 |  | ||||||
|     private Condition iterationCountCondition; |  | ||||||
| 
 |  | ||||||
|     protected FixedIterationCountCondition(int iterationCount) { |  | ||||||
|         iterationCountCondition = new GreaterThanOrEqual(iterationCount); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param iterationCount the number of iterations after which the condition is satisfied |  | ||||||
|      * @return the condition |  | ||||||
|      */ |  | ||||||
|     public static FixedIterationCountCondition iterationCountGreaterThan(int iterationCount) { |  | ||||||
|         return new FixedIterationCountCondition(iterationCount); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param iterationHistory the history of the clustering run so far (may be null) |  | ||||||
|      * @return true once the recorded iteration count reaches the configured bound |  | ||||||
|      */ |  | ||||||
|     public boolean isSatisfied(IterationHistory iterationHistory) { |  | ||||||
|         return iterationCountCondition.apply(iterationHistory == null ? 0 : iterationHistory.getIterationCount()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
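The corresponding sketch for an iteration cap; the bound of 20 is illustrative:

    ClusteringAlgorithmCondition capped = FixedIterationCountCondition.iterationCountGreaterThan(20);
    // typically checked alongside a convergence condition: stop as soon as either is satisfied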
| @ -1,82 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.condition; |  | ||||||
| 
 |  | ||||||
| import lombok.AccessLevel; |  | ||||||
| import lombok.AllArgsConstructor; |  | ||||||
| import lombok.NoArgsConstructor; |  | ||||||
| import org.deeplearning4j.clustering.iteration.IterationHistory; |  | ||||||
| import org.nd4j.linalg.indexing.conditions.Condition; |  | ||||||
| import org.nd4j.linalg.indexing.conditions.LessThan; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * Condition satisfied once the variance of point-to-center distances has stabilized |  | ||||||
|  * over a given number of consecutive iterations. |  | ||||||
|  */ |  | ||||||
| @NoArgsConstructor(access = AccessLevel.PROTECTED) |  | ||||||
| @AllArgsConstructor |  | ||||||
| public class VarianceVariationCondition implements ClusteringAlgorithmCondition, Serializable { |  | ||||||
| 
 |  | ||||||
|     private Condition varianceVariationCondition; |  | ||||||
|     private int period; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param varianceVariation maximum tolerated relative change of the distance variance |  | ||||||
|      * @param period number of consecutive iterations over which the variance must be stable |  | ||||||
|      * @return the condition |  | ||||||
|      */ |  | ||||||
|     public static VarianceVariationCondition varianceVariationLessThan(double varianceVariation, int period) { |  | ||||||
|         Condition condition = new LessThan(varianceVariation); |  | ||||||
|         return new VarianceVariationCondition(condition, period); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param iterationHistory the history of the clustering run so far |  | ||||||
|      * @return true once the relative variance change stayed within bounds for 'period' iterations |  | ||||||
|      */ |  | ||||||
|     public boolean isSatisfied(IterationHistory iterationHistory) { |  | ||||||
|         if (iterationHistory.getIterationCount() <= period) |  | ||||||
|             return false; |  | ||||||
| 
 |  | ||||||
|         //walk the last 'period' iteration pairs; iteration indices are 0-based, |  | ||||||
|         //so the most recent iteration sits at getIterationCount() - 1 |  | ||||||
|         for (int i = 0, j = iterationHistory.getIterationCount(); i < period; i++) { |  | ||||||
|             double variation = iterationHistory.getIterationInfo(j - i - 1).getClusterSetInfo() |  | ||||||
|                             .getPointDistanceFromClusterVariance(); |  | ||||||
|             variation -= iterationHistory.getIterationInfo(j - i - 2).getClusterSetInfo() |  | ||||||
|                             .getPointDistanceFromClusterVariance(); |  | ||||||
|             variation /= iterationHistory.getIterationInfo(j - i - 2).getClusterSetInfo() |  | ||||||
|                             .getPointDistanceFromClusterVariance(); |  | ||||||
| 
 |  | ||||||
|             if (!varianceVariationCondition.apply(variation)) |  | ||||||
|                 return false; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return true; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
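The loop above tests the relative change of the distance variance between consecutive iterations; a worked example with hypothetical numbers:

    double previous = 4.0, current = 3.9;               // variances of two consecutive iterations
    double variation = (current - previous) / previous; // -0.025, i.e. a 2.5% decrease
    // varianceVariationLessThan(0.05, period) accepts this step, since -0.025 < 0.05;
    // the condition is satisfied only when every step in the window passes the test.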
| @ -1,114 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.info; |  | ||||||
| 
 |  | ||||||
| import lombok.Data; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| import java.util.*; |  | ||||||
| import java.util.concurrent.ConcurrentHashMap; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * Per-cluster statistics: each point's distance from the center plus derived aggregates. |  | ||||||
|  */ |  | ||||||
| @Data |  | ||||||
| public class ClusterInfo implements Serializable { |  | ||||||
| 
 |  | ||||||
|     private double averagePointDistanceFromCenter; |  | ||||||
|     private double maxPointDistanceFromCenter; |  | ||||||
|     private double pointDistanceFromCenterVariance; |  | ||||||
|     private double totalPointDistanceFromCenter; |  | ||||||
|     private boolean inverse; |  | ||||||
|     private Map<String, Double> pointDistancesFromCenter = new ConcurrentHashMap<>(); |  | ||||||
| 
 |  | ||||||
|     public ClusterInfo(boolean inverse) { |  | ||||||
|         this(false, inverse); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param threadSafe whether the distance map must support concurrent access |  | ||||||
|      * @param inverse whether the distance function is inverse (similarity-based) |  | ||||||
|      */ |  | ||||||
|     public ClusterInfo(boolean threadSafe, boolean inverse) { |  | ||||||
|         super(); |  | ||||||
|         this.inverse = inverse; |  | ||||||
|         if (threadSafe) { |  | ||||||
|             pointDistancesFromCenter = Collections.synchronizedMap(pointDistancesFromCenter); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @return the point-to-center distances in ascending order; ties are kept as distinct entries |  | ||||||
|      */ |  | ||||||
|     public Set<Map.Entry<String, Double>> getSortedPointDistancesFromCenter() { |  | ||||||
|         SortedSet<Map.Entry<String, Double>> sortedEntries = new TreeSet<>(new Comparator<Map.Entry<String, Double>>() { |  | ||||||
|             @Override |  | ||||||
|             public int compare(Map.Entry<String, Double> e1, Map.Entry<String, Double> e2) { |  | ||||||
|                 int res = e1.getValue().compareTo(e2.getValue()); |  | ||||||
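|                 //never return 0: equal distances must stay distinct entries in the TreeSet |  | ||||||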
|                 return res != 0 ? res : 1; |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
|         sortedEntries.addAll(pointDistancesFromCenter.entrySet()); |  | ||||||
|         return sortedEntries; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @return the point-to-center distances in descending order; ties are kept as distinct entries |  | ||||||
|      */ |  | ||||||
|     public Set<Map.Entry<String, Double>> getReverseSortedPointDistancesFromCenter() { |  | ||||||
|         SortedSet<Map.Entry<String, Double>> sortedEntries = new TreeSet<>(new Comparator<Map.Entry<String, Double>>() { |  | ||||||
|             @Override |  | ||||||
|             public int compare(Map.Entry<String, Double> e1, Map.Entry<String, Double> e2) { |  | ||||||
|                 int res = e1.getValue().compareTo(e2.getValue()); |  | ||||||
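|                 //negated for descending order; again never 0, so ties are kept |  | ||||||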
|                 return -(res != 0 ? res : 1); |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
|         sortedEntries.addAll(pointDistancesFromCenter.entrySet()); |  | ||||||
|         return sortedEntries; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param maxDistance the distance threshold |  | ||||||
|      * @return ids of the points farther from the center than 'maxDistance' |  | ||||||
|      */ |  | ||||||
|     public List<String> getPointsFartherFromCenterThan(double maxDistance) { |  | ||||||
|         Set<Map.Entry<String, Double>> sorted = getReverseSortedPointDistancesFromCenter(); |  | ||||||
|         List<String> ids = new ArrayList<>(); |  | ||||||
|         for (Map.Entry<String, Double> entry : sorted) { |  | ||||||
|             if (inverse && entry.getValue() < -maxDistance) |  | ||||||
|                 break; |  | ||||||
|             else if (entry.getValue() > maxDistance) |  | ||||||
|                 break; |  | ||||||
| 
 |  | ||||||
|             ids.add(entry.getKey()); |  | ||||||
|         } |  | ||||||
|         return ids; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
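Why the comparators above never return 0: a contract-honoring comparator would let the TreeSet collapse equal distances, silently dropping points from the statistics. A tiny sketch:

    SortedSet<Double> naive = new TreeSet<>();
    naive.add(1.0);
    naive.add(1.0);
    // naive.size() == 1 -- the duplicate distance disappears, which is exactly what
    // the tie-breaking constant 1 in the comparators above prevents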
| @ -1,142 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.info; |  | ||||||
| 
 |  | ||||||
| import org.nd4j.shade.guava.collect.HashBasedTable; |  | ||||||
| import org.nd4j.shade.guava.collect.Table; |  | ||||||
| import org.deeplearning4j.clustering.cluster.Cluster; |  | ||||||
| import org.deeplearning4j.clustering.cluster.ClusterSet; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| import java.util.Collections; |  | ||||||
| import java.util.HashMap; |  | ||||||
| import java.util.List; |  | ||||||
| import java.util.Map; |  | ||||||
| import java.util.concurrent.atomic.AtomicInteger; |  | ||||||
| 
 |  | ||||||
| public class ClusterSetInfo implements Serializable { |  | ||||||
| 
 |  | ||||||
|     private Map<String, ClusterInfo> clustersInfos = new HashMap<>(); |  | ||||||
|     private Table<String, String, Double> distancesBetweenClustersCenters = HashBasedTable.create(); |  | ||||||
|     private AtomicInteger pointLocationChange; //number of points that changed cluster during the last iteration |  | ||||||
|     private boolean threadSafe; |  | ||||||
|     private boolean inverse; |  | ||||||
| 
 |  | ||||||
|     public ClusterSetInfo(boolean inverse) { |  | ||||||
|         this(inverse, false); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param inverse whether the distance function is inverse (similarity-based) |  | ||||||
|      * @param threadSafe whether the info map must support concurrent access |  | ||||||
|      */ |  | ||||||
|     public ClusterSetInfo(boolean inverse, boolean threadSafe) { |  | ||||||
|         this.pointLocationChange = new AtomicInteger(0); |  | ||||||
|         this.threadSafe = threadSafe; |  | ||||||
|         this.inverse = inverse; |  | ||||||
|         if (threadSafe) { |  | ||||||
|             clustersInfos = Collections.synchronizedMap(clustersInfos); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Creates a ClusterSetInfo with one (empty) ClusterInfo per cluster in the set. |  | ||||||
|      * @param clusterSet the cluster set to describe |  | ||||||
|      * @param threadSafe whether the created infos must support concurrent access |  | ||||||
|      * @return the initialized info |  | ||||||
|      */ |  | ||||||
|     public static ClusterSetInfo initialize(ClusterSet clusterSet, boolean threadSafe) { |  | ||||||
|         ClusterSetInfo info = new ClusterSetInfo(clusterSet.isInverse(), threadSafe); |  | ||||||
|         for (int i = 0, j = clusterSet.getClusterCount(); i < j; i++) |  | ||||||
|             info.addClusterInfo(clusterSet.getClusters().get(i).getId()); |  | ||||||
|         return info; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void removeClusterInfos(List<Cluster> clusters) { |  | ||||||
|         for (Cluster cluster : clusters) { |  | ||||||
|             clustersInfos.remove(cluster.getId()); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public ClusterInfo addClusterInfo(String clusterId) { |  | ||||||
|         ClusterInfo clusterInfo = new ClusterInfo(this.threadSafe, this.inverse); //forward both flags to the two-arg constructor |  | ||||||
|         clustersInfos.put(clusterId, clusterInfo); |  | ||||||
|         return clusterInfo; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public ClusterInfo getClusterInfo(String clusterId) { |  | ||||||
|         return clustersInfos.get(clusterId); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public double getAveragePointDistanceFromClusterCenter() { |  | ||||||
|         if (clustersInfos == null || clustersInfos.isEmpty()) |  | ||||||
|             return 0; |  | ||||||
| 
 |  | ||||||
|         double average = 0; |  | ||||||
|         for (ClusterInfo info : clustersInfos.values()) |  | ||||||
|             average += info.getAveragePointDistanceFromCenter(); |  | ||||||
|         return average / clustersInfos.size(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public double getPointDistanceFromClusterVariance() { |  | ||||||
|         if (clustersInfos == null || clustersInfos.isEmpty()) |  | ||||||
|             return 0; |  | ||||||
| 
 |  | ||||||
|         double average = 0; |  | ||||||
|         for (ClusterInfo info : clustersInfos.values()) |  | ||||||
|             average += info.getPointDistanceFromCenterVariance(); |  | ||||||
|         return average / clustersInfos.size(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public int getPointsCount() { |  | ||||||
|         int count = 0; |  | ||||||
|         for (ClusterInfo clusterInfo : clustersInfos.values()) |  | ||||||
|             count += clusterInfo.getPointDistancesFromCenter().size(); |  | ||||||
|         return count; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public Map<String, ClusterInfo> getClustersInfos() { |  | ||||||
|         return clustersInfos; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setClustersInfos(Map<String, ClusterInfo> clustersInfos) { |  | ||||||
|         this.clustersInfos = clustersInfos; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public Table<String, String, Double> getDistancesBetweenClustersCenters() { |  | ||||||
|         return distancesBetweenClustersCenters; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setDistancesBetweenClustersCenters(Table<String, String, Double> interClusterDistances) { |  | ||||||
|         this.distancesBetweenClustersCenters = interClusterDistances; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public AtomicInteger getPointLocationChange() { |  | ||||||
|         return pointLocationChange; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setPointLocationChange(AtomicInteger pointLocationChange) { |  | ||||||
|         this.pointLocationChange = pointLocationChange; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
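A lifecycle sketch; 'clusterSet' stands for a ClusterSet built elsewhere in this module:

    ClusterSetInfo info = ClusterSetInfo.initialize(clusterSet, true); // one ClusterInfo per cluster
    info.getPointLocationChange().incrementAndGet();                   // a point switched cluster
    double spread = info.getAveragePointDistanceFromClusterCenter();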
| @ -1,72 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.iteration; |  | ||||||
| 
 |  | ||||||
| import lombok.Getter; |  | ||||||
| import lombok.Setter; |  | ||||||
| import org.deeplearning4j.clustering.info.ClusterSetInfo; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| import java.util.HashMap; |  | ||||||
| import java.util.Map; |  | ||||||
| 
 |  | ||||||
| public class IterationHistory implements Serializable { |  | ||||||
|     @Getter |  | ||||||
|     @Setter |  | ||||||
|     private Map<Integer, IterationInfo> iterationsInfos = new HashMap<>(); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @return the cluster set statistics of the most recent iteration, or null if no iteration ran yet |  | ||||||
|      */ |  | ||||||
|     public ClusterSetInfo getMostRecentClusterSetInfo() { |  | ||||||
|         IterationInfo iterationInfo = getMostRecentIterationInfo(); |  | ||||||
|         return iterationInfo == null ? null : iterationInfo.getClusterSetInfo(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @return the most recent iteration's info, or null if no iteration ran yet |  | ||||||
|      */ |  | ||||||
|     public IterationInfo getMostRecentIterationInfo() { |  | ||||||
|         return getIterationInfo(getIterationCount() - 1); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @return the number of iterations recorded so far |  | ||||||
|      */ |  | ||||||
|     public int getIterationCount() { |  | ||||||
|         return getIterationsInfos().size(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param iterationIdx the 0-based iteration index |  | ||||||
|      * @return the info recorded for that iteration, or null if absent |  | ||||||
|      */ |  | ||||||
|     public IterationInfo getIterationInfo(int iterationIdx) { |  | ||||||
|         return getIterationsInfos().get(iterationIdx); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
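A small usage sketch; 'setInfo' stands for a ClusterSetInfo captured at the end of an iteration:

    IterationHistory history = new IterationHistory();
    history.getIterationsInfos().put(0, new IterationInfo(0, setInfo)); // record iteration 0
    ClusterSetInfo latest = history.getMostRecentClusterSetInfo();      // returns setInfo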
| @ -1,49 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.iteration; |  | ||||||
| 
 |  | ||||||
| import lombok.AccessLevel; |  | ||||||
| import lombok.Data; |  | ||||||
| import lombok.NoArgsConstructor; |  | ||||||
| import org.deeplearning4j.clustering.info.ClusterSetInfo; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| @NoArgsConstructor(access = AccessLevel.PROTECTED) |  | ||||||
| public class IterationInfo implements Serializable { |  | ||||||
| 
 |  | ||||||
|     private int index; //0-based index of the iteration |  | ||||||
|     private ClusterSetInfo clusterSetInfo; //statistics captured at the end of the iteration |  | ||||||
|     private boolean strategyApplied; //whether an optimization strategy ran during the iteration |  | ||||||
| 
 |  | ||||||
|     public IterationInfo(int index) { |  | ||||||
|         super(); |  | ||||||
|         this.index = index; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public IterationInfo(int index, ClusterSetInfo clusterSetInfo) { |  | ||||||
|         super(); |  | ||||||
|         this.index = index; |  | ||||||
|         this.clusterSetInfo = clusterSetInfo; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,142 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.kdtree; |  | ||||||
| 
 |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.api.ops.custom.KnnMinDistance; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| import org.nd4j.common.primitives.Pair; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| 
 |  | ||||||
| public class HyperRect implements Serializable { |  | ||||||
| 
 |  | ||||||
|     private float[] lowerEnds; |  | ||||||
|     private float[] higherEnds; |  | ||||||
|     private INDArray lowerEndsIND; |  | ||||||
|     private INDArray higherEndsIND; |  | ||||||
| 
 |  | ||||||
|     public HyperRect(float[] lowerEndsIn, float[] higherEndsIn) { |  | ||||||
|         this.lowerEnds = new float[lowerEndsIn.length]; |  | ||||||
|         this.higherEnds = new float[lowerEndsIn.length]; |  | ||||||
|         System.arraycopy(lowerEndsIn, 0 , this.lowerEnds, 0, lowerEndsIn.length); |  | ||||||
|         System.arraycopy(higherEndsIn, 0 , this.higherEnds, 0, higherEndsIn.length); |  | ||||||
|         lowerEndsIND = Nd4j.createFromArray(lowerEnds); |  | ||||||
|         higherEndsIND = Nd4j.createFromArray(higherEnds); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public HyperRect(float[] point) { |  | ||||||
|         this(point, point); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public HyperRect(Pair<float[], float[]> ends) { |  | ||||||
|         this(ends.getFirst(), ends.getSecond()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public void enlargeTo(INDArray point) { |  | ||||||
|         float[] pointAsArray = point.toFloatVector(); |  | ||||||
|         for (int i = 0; i < lowerEnds.length; i++) { |  | ||||||
|             float p = pointAsArray[i]; |  | ||||||
|             if (lowerEnds[i] > p) |  | ||||||
|                 lowerEnds[i] = p; |  | ||||||
|             else if (higherEnds[i] < p) |  | ||||||
|                 higherEnds[i] = p; |  | ||||||
|         } |  | ||||||
|         //refresh the INDArray bounds so minDistance sees the enlarged rectangle |  | ||||||
|         lowerEndsIND = Nd4j.createFromArray(lowerEnds); |  | ||||||
|         higherEndsIND = Nd4j.createFromArray(higherEnds); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public static Pair<float[],float[]> point(INDArray vector) { |  | ||||||
|         Pair<float[],float[]> ret = new Pair<>(); |  | ||||||
|         float[] curr = new float[(int)vector.length()]; |  | ||||||
|         for (int i = 0; i < vector.length(); i++) { |  | ||||||
|             curr[i] = vector.getFloat(i); |  | ||||||
|         } |  | ||||||
|         ret.setFirst(curr); |  | ||||||
|         ret.setSecond(curr); |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /*public List<Boolean> contains(INDArray hPoint) { |  | ||||||
|         List<Boolean> ret = new ArrayList<>(); |  | ||||||
|         for (int i = 0; i < hPoint.length(); i++) { |  | ||||||
|             ret.add(lowerEnds[i] <= hPoint.getDouble(i) && |  | ||||||
|                     higherEnds[i] >= hPoint.getDouble(i)); |  | ||||||
|         } |  | ||||||
|         return ret; |  | ||||||
|     }*/ |  | ||||||
| 
 |  | ||||||
|     public double minDistance(INDArray hPoint, INDArray output) { |  | ||||||
|         Nd4j.exec(new KnnMinDistance(hPoint, lowerEndsIND, higherEndsIND, output)); |  | ||||||
|         return output.getFloat(0); |  | ||||||
| 
 |  | ||||||
|         /*double ret = 0.0; |  | ||||||
|         double[] pointAsArray = hPoint.toDoubleVector(); |  | ||||||
|         for (int i = 0; i < pointAsArray.length; i++) { |  | ||||||
|            double p = pointAsArray[i]; |  | ||||||
|            if (!(lowerEnds[i] <= p || higherEnds[i] <= p)) { |  | ||||||
|               if (p < lowerEnds[i]) |  | ||||||
|                  ret += Math.pow((p - lowerEnds[i]), 2); |  | ||||||
|               else |  | ||||||
|                  ret += Math.pow((p - higherEnds[i]), 2); |  | ||||||
|            } |  | ||||||
|         } |  | ||||||
|         ret = Math.pow(ret, 0.5); |  | ||||||
|         return ret;*/ |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public HyperRect getUpper(INDArray hPoint, int desc) { |  | ||||||
|         float higher = higherEnds[desc]; |  | ||||||
|         float d = hPoint.getFloat(desc); |  | ||||||
|         if (higher < d) |  | ||||||
|             return null; |  | ||||||
|         HyperRect ret = new HyperRect(lowerEnds, higherEnds); |  | ||||||
|         if (ret.lowerEnds[desc] < d) { |  | ||||||
|             ret.lowerEnds[desc] = d; |  | ||||||
|             ret.lowerEndsIND.putScalar(desc, d); //keep the INDArray bounds in sync |  | ||||||
|         } |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public HyperRect getLower(INDArray hPoint, int desc) { |  | ||||||
|         float lower = lowerEnds[desc]; |  | ||||||
|         float d = hPoint.getFloat(desc); |  | ||||||
|         if (lower > d) |  | ||||||
|             return null; |  | ||||||
|         HyperRect ret = new HyperRect(lowerEnds, higherEnds); |  | ||||||
|         if (ret.higherEnds[desc] > d) { |  | ||||||
|             ret.higherEnds[desc] = d; |  | ||||||
|             ret.higherEndsIND.putScalar(desc, d); //keep the INDArray bounds in sync |  | ||||||
|         } |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public String toString() { |  | ||||||
|         String retVal = ""; |  | ||||||
|         retVal +=  "["; |  | ||||||
|         for (int i = 0; i < lowerEnds.length; ++i) { |  | ||||||
|             retVal +=  "("  + lowerEnds[i] + " - " + higherEnds[i] + ") "; |  | ||||||
|         } |  | ||||||
|         retVal +=  "]"; |  | ||||||
|         return retVal; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
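A sketch of how the rectangle is used during KD-tree descent; the coordinates are hypothetical:

    HyperRect rect = new HyperRect(new float[] {0f, 0f}, new float[] {4f, 4f});
    rect.enlargeTo(Nd4j.createFromArray(new float[] {5f, 1f})); // now spans x in [0, 5]
    INDArray nodePoint = Nd4j.createFromArray(new float[] {2f, 2f});
    HyperRect left = rect.getLower(nodePoint, 0);  // x in [0, 2]
    HyperRect right = rect.getUpper(nodePoint, 0); // x in [2, 5]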
| @ -1,370 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.kdtree; |  | ||||||
| 
 |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.api.ops.impl.reduce.bool.Any; |  | ||||||
| import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| import org.nd4j.common.primitives.Pair; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.Collections; |  | ||||||
| import java.util.Comparator; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| public class KDTree implements Serializable { |  | ||||||
| 
 |  | ||||||
|     private KDNode root; |  | ||||||
|     private int dims = 100; |  | ||||||
|     public final static int GREATER = 1; |  | ||||||
|     public final static int LESS = 0; |  | ||||||
|     private int size = 0; |  | ||||||
|     private HyperRect rect; |  | ||||||
| 
 |  | ||||||
|     public KDTree(int dims) { |  | ||||||
|         this.dims = dims; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Insert a point in to the tree |  | ||||||
|      * @param point the point to insert |  | ||||||
|      */ |  | ||||||
|     public void insert(INDArray point) { |  | ||||||
|         if (!point.isVector() || point.length() != dims) |  | ||||||
|             throw new IllegalArgumentException("Point must be a vector of length " + dims); |  | ||||||
| 
 |  | ||||||
|         if (root == null) { |  | ||||||
|             root = new KDNode(point); |  | ||||||
|             rect = new HyperRect(/*HyperRect.point(point)*/ point.toFloatVector()); |  | ||||||
|         } else { |  | ||||||
|             int disc = 0; |  | ||||||
|             KDNode node = root; |  | ||||||
|             KDNode insert = new KDNode(point); |  | ||||||
|             int successor; |  | ||||||
|             while (true) { |  | ||||||
|                 //exactly equal |  | ||||||
|                 INDArray pt = node.getPoint(); |  | ||||||
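|                 //Any(pt.neq(point)) is 0 only when every component matches, i.e. the point is already present |  | ||||||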
|                 INDArray countEq = Nd4j.getExecutioner().execAndReturn(new Any(pt.neq(point))).z(); |  | ||||||
|                 if (countEq.getInt(0) == 0) { |  | ||||||
|                     return; |  | ||||||
|                 } else { |  | ||||||
|                     successor = successor(node, point, disc); |  | ||||||
|                     KDNode child; |  | ||||||
|                     if (successor < 1) |  | ||||||
|                         child = node.getLeft(); |  | ||||||
|                     else |  | ||||||
|                         child = node.getRight(); |  | ||||||
|                     if (child == null) |  | ||||||
|                         break; |  | ||||||
|                     disc = (disc + 1) % dims; |  | ||||||
|                     node = child; |  | ||||||
|                 } |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             if (successor < 1) |  | ||||||
|                 node.setLeft(insert); |  | ||||||
| 
 |  | ||||||
|             else |  | ||||||
|                 node.setRight(insert); |  | ||||||
| 
 |  | ||||||
|             rect.enlargeTo(point); |  | ||||||
|             insert.setParent(node); |  | ||||||
|         } |  | ||||||
|         size++; |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
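|     /** |  | ||||||
|      * Removes the given point from the tree, if present. |  | ||||||
|      * @param point the point to remove |  | ||||||
|      * @return the removed point, or null if it was not found |  | ||||||
|      */ |  | ||||||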
|     public INDArray delete(INDArray point) { |  | ||||||
|         KDNode node = root; |  | ||||||
|         int _disc = 0; |  | ||||||
|         while (node != null) { |  | ||||||
|             if (node.point.equals(point)) //value comparison, matching insert()'s duplicate check |  | ||||||
|                 break; |  | ||||||
|             int successor = successor(node, point, _disc); |  | ||||||
|             if (successor < 1) |  | ||||||
|                 node = node.getLeft(); |  | ||||||
|             else |  | ||||||
|                 node = node.getRight(); |  | ||||||
|             _disc = (_disc + 1) % dims; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         if (node == null) |  | ||||||
|             return null; //point not found |  | ||||||
|         //capture the point first: the recursive delete may null out 'node' itself |  | ||||||
|         INDArray deletedPoint = node.getPoint(); |  | ||||||
|         if (node == root) { |  | ||||||
|             root = delete(root, _disc); |  | ||||||
|         } else |  | ||||||
|             node = delete(node, _disc); |  | ||||||
|         size--; |  | ||||||
|         if (size == 1) { |  | ||||||
|             rect = new HyperRect(HyperRect.point(deletedPoint)); |  | ||||||
|         } else if (size == 0) |  | ||||||
|             rect = null; |  | ||||||
|         return deletedPoint; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // Shared state for the recursive knn/nn calls; queries are therefore not thread-safe |  | ||||||
|     private float currentDistance; |  | ||||||
|     private INDArray currentPoint; |  | ||||||
|     private INDArray minDistance = Nd4j.scalar(0.f); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
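|     /** |  | ||||||
|      * Range search: collects every point within 'distance' of the query point. |  | ||||||
|      * @param point the point to query around |  | ||||||
|      * @param distance the search radius |  | ||||||
|      * @return (distance, point) pairs sorted nearest-first |  | ||||||
|      */ |  | ||||||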
|     public List<Pair<Float, INDArray>> knn(INDArray point, float distance) { |  | ||||||
|         List<Pair<Float, INDArray>> best = new ArrayList<>(); |  | ||||||
|         currentDistance = distance; |  | ||||||
|         currentPoint = point; |  | ||||||
|         knn(root, rect, best, 0); |  | ||||||
|         Collections.sort(best, new Comparator<Pair<Float, INDArray>>() { |  | ||||||
|             @Override |  | ||||||
|             public int compare(Pair<Float, INDArray> o1, Pair<Float, INDArray> o2) { |  | ||||||
|                 return Float.compare(o1.getKey(), o2.getKey()); |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|         return best; |  | ||||||
|     } |  | ||||||
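|     //Usage sketch (hypothetical 2-d data): |  | ||||||
|     //  KDTree tree = new KDTree(2); |  | ||||||
|     //  tree.insert(Nd4j.createFromArray(new float[] {0f, 0f})); |  | ||||||
|     //  tree.insert(Nd4j.createFromArray(new float[] {3f, 4f})); |  | ||||||
|     //  tree.knn(Nd4j.createFromArray(new float[] {0f, 0f}), 5f); //returns both points, nearest first |  | ||||||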
| 
 |  | ||||||
| 
 |  | ||||||
|     private void knn(KDNode node, HyperRect rect, List<Pair<Float, INDArray>> best, int _disc) { |  | ||||||
|         if (node == null || rect == null || rect.minDistance(currentPoint, minDistance) > currentDistance) |  | ||||||
|             return; |  | ||||||
|         int _discNext = (_disc + 1) % dims; |  | ||||||
|         float distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(currentPoint,node.point, minDistance)).getFinalResult() |  | ||||||
|                 .floatValue(); |  | ||||||
| 
 |  | ||||||
|         if (distance <= currentDistance) { |  | ||||||
|             best.add(Pair.of(distance, node.getPoint())); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         HyperRect lower = rect.getLower(node.point, _disc); |  | ||||||
|         HyperRect upper = rect.getUpper(node.point, _disc); |  | ||||||
|         knn(node.getLeft(), lower, best, _discNext); |  | ||||||
|         knn(node.getRight(), upper, best, _discNext); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Query for nearest neighbor. Returns the distance and point |  | ||||||
|      * @param point the point to query for |  | ||||||
|      * @return the distance to, and the coordinates of, the nearest neighbor |  | ||||||
|      */ |  | ||||||
|     public Pair<Double, INDArray> nn(INDArray point) { |  | ||||||
|         return nn(root, point, rect, Double.POSITIVE_INFINITY, null, 0); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     private Pair<Double, INDArray> nn(KDNode node, INDArray point, HyperRect rect, double dist, INDArray best, |  | ||||||
|                     int _disc) { |  | ||||||
|         if (node == null || rect.minDistance(point, minDistance) > dist) |  | ||||||
|             return Pair.of(Double.POSITIVE_INFINITY, null); |  | ||||||
| 
 |  | ||||||
|         int _discNext = (_disc + 1) % dims; |  | ||||||
|         // Distance from the query to this node's point. (The original compared the query against |  | ||||||
|         // a zero vector, which computes the query's norm and ignores the node entirely.) |  | ||||||
|         double dist2 = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(point, node.point)).getFinalResult().doubleValue(); |  | ||||||
|         if (dist2 < dist) { |  | ||||||
|             best = node.getPoint(); |  | ||||||
|             dist = dist2; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         HyperRect lower = rect.getLower(node.point, _disc); |  | ||||||
|         HyperRect upper = rect.getUpper(node.point, _disc); |  | ||||||
| 
 |  | ||||||
|         if (point.getDouble(_disc) < node.point.getDouble(_disc)) { |  | ||||||
|             Pair<Double, INDArray> left = nn(node.getLeft(), point, lower, dist, best, _discNext); |  | ||||||
|             Pair<Double, INDArray> right = nn(node.getRight(), point, upper, dist, best, _discNext); |  | ||||||
|             if (left.getKey() < dist) |  | ||||||
|                 return left; |  | ||||||
|             else if (right.getKey() < dist) |  | ||||||
|                 return right; |  | ||||||
| 
 |  | ||||||
|         } else { |  | ||||||
|             Pair<Double, INDArray> left = nn(node.getRight(), point, upper, dist, best, _discNext); |  | ||||||
|             Pair<Double, INDArray> right = nn(node.getLeft(), point, lower, dist, best, _discNext); |  | ||||||
|             if (left.getKey() < dist) |  | ||||||
|                 return left; |  | ||||||
|             else if (right.getKey() < dist) |  | ||||||
|                 return right; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return Pair.of(dist, best); |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private KDNode delete(KDNode delete, int _disc) { |  | ||||||
|         // Leaf case: unlink the node from its parent. (The original tested for two non-null |  | ||||||
|         // children here, which both dropped whole subtrees and left unlinked leaves in the tree.) |  | ||||||
|         if (delete.getLeft() == null && delete.getRight() == null) { |  | ||||||
|             if (delete.getParent() != null) { |  | ||||||
|                 if (delete.getParent().getLeft() == delete) |  | ||||||
|                     delete.getParent().setLeft(null); |  | ||||||
|                 else |  | ||||||
|                     delete.getParent().setRight(null); |  | ||||||
| 
 |  | ||||||
|             } |  | ||||||
|             return null; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         int disc = _disc; |  | ||||||
|         _disc = (_disc + 1) % dims; |  | ||||||
|         Pair<KDNode, Integer> qd = null; |  | ||||||
|         if (delete.getRight() != null) { |  | ||||||
|             qd = min(delete.getRight(), disc, _disc); |  | ||||||
|         } else if (delete.getLeft() != null) |  | ||||||
|             qd = max(delete.getLeft(), disc, _disc); |  | ||||||
|         if (qd == null) {// is leaf |  | ||||||
|             return null; |  | ||||||
|         } |  | ||||||
|         delete.point = qd.getKey().point; |  | ||||||
|         KDNode qFather = qd.getKey().getParent(); |  | ||||||
|         if (qFather.getLeft() == qd.getKey()) { |  | ||||||
|             qFather.setLeft(delete(qd.getKey(), disc)); |  | ||||||
|         } else if (qFather.getRight() == qd.getKey()) { |  | ||||||
|             qFather.setRight(delete(qd.getKey(), disc)); |  | ||||||
| 
 |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return delete; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     private Pair<KDNode, Integer> max(KDNode node, int disc, int _disc) { |  | ||||||
|         int discNext = (_disc + 1) % dims; |  | ||||||
|         if (_disc == disc) { |  | ||||||
|             KDNode child = node.getRight(); // the maximum along the discriminant lies in the right subtree; min() correctly mirrors this with getLeft() |  | ||||||
|             if (child != null) { |  | ||||||
|                 return max(child, disc, discNext); |  | ||||||
|             } |  | ||||||
|         } else if (node.getLeft() != null || node.getRight() != null) { |  | ||||||
|             Pair<KDNode, Integer> left = null, right = null; |  | ||||||
|             if (node.getLeft() != null) |  | ||||||
|                 left = max(node.getLeft(), disc, discNext); |  | ||||||
|             if (node.getRight() != null) |  | ||||||
|                 right = max(node.getRight(), disc, discNext); |  | ||||||
|             if (left != null && right != null) { |  | ||||||
|                 double pointLeft = left.getKey().getPoint().getDouble(disc); |  | ||||||
|                 double pointRight = right.getKey().getPoint().getDouble(disc); |  | ||||||
|                 if (pointLeft > pointRight) |  | ||||||
|                     return left; |  | ||||||
|                 else |  | ||||||
|                     return right; |  | ||||||
|             } else if (left != null) |  | ||||||
|                 return left; |  | ||||||
|             else |  | ||||||
|                 return right; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return Pair.of(node, _disc); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     private Pair<KDNode, Integer> min(KDNode node, int disc, int _disc) { |  | ||||||
|         int discNext = (_disc + 1) % dims; |  | ||||||
|         if (_disc == disc) { |  | ||||||
|             KDNode child = node.getLeft(); |  | ||||||
|             if (child != null) { |  | ||||||
|                 return min(child, disc, discNext); |  | ||||||
|             } |  | ||||||
|         } else if (node.getLeft() != null || node.getRight() != null) { |  | ||||||
|             Pair<KDNode, Integer> left = null, right = null; |  | ||||||
|             if (node.getLeft() != null) |  | ||||||
|                 left = min(node.getLeft(), disc, discNext); |  | ||||||
|             if (node.getRight() != null) |  | ||||||
|                 right = min(node.getRight(), disc, discNext); |  | ||||||
|             if (left != null && right != null) { |  | ||||||
|                 double pointLeft = left.getKey().getPoint().getDouble(disc); |  | ||||||
|                 double pointRight = right.getKey().getPoint().getDouble(disc); |  | ||||||
|                 if (pointLeft < pointRight) |  | ||||||
|                     return left; |  | ||||||
|                 else |  | ||||||
|                     return right; |  | ||||||
|             } else if (left != null) |  | ||||||
|                 return left; |  | ||||||
|             else |  | ||||||
|                 return right; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return Pair.of(node, _disc); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * The number of elements in the tree |  | ||||||
|      * @return the number of elements in the tree |  | ||||||
|      */ |  | ||||||
|     public int size() { |  | ||||||
|         return size; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private int successor(KDNode node, INDArray point, int disc) { |  | ||||||
|         for (int i = disc; i < dims; i++) { |  | ||||||
|             double pointI = point.getDouble(i); |  | ||||||
|             double nodePointI = node.getPoint().getDouble(i); |  | ||||||
|             if (pointI < nodePointI) |  | ||||||
|                 return LESS; |  | ||||||
|             else if (pointI > nodePointI) |  | ||||||
|                 return GREATER; |  | ||||||
| 
 |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         throw new IllegalStateException("Point is equal!"); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     private static class KDNode { |  | ||||||
|         private INDArray point; |  | ||||||
|         private KDNode left, right, parent; |  | ||||||
| 
 |  | ||||||
|         public KDNode(INDArray point) { |  | ||||||
|             this.point = point; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         public INDArray getPoint() { |  | ||||||
|             return point; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         public KDNode getLeft() { |  | ||||||
|             return left; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         public void setLeft(KDNode left) { |  | ||||||
|             this.left = left; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         public KDNode getRight() { |  | ||||||
|             return right; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         public void setRight(KDNode right) { |  | ||||||
|             this.right = right; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         public KDNode getParent() { |  | ||||||
|             return parent; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         public void setParent(KDNode parent) { |  | ||||||
|             this.parent = parent; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
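
The methods above are easier to follow with a usage sketch. This is a hedged example, not part of the deleted file: the two-argument constructor KDTree(int dims) and the insert(INDArray) method live in the portion of the file above this excerpt, so their exact signatures are assumptions here; knn, nn and size are the methods shown above.

    import org.nd4j.common.primitives.Pair;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    import java.util.List;

    public class KDTreeExample {
        public static void main(String[] args) {
            KDTree tree = new KDTree(2); // assumed constructor: the dimensionality of the indexed points

            tree.insert(Nd4j.createFromArray(1.0f, 2.0f)); // insert(INDArray) is defined above this excerpt
            tree.insert(Nd4j.createFromArray(3.0f, 1.0f));
            tree.insert(Nd4j.createFromArray(-2.0f, 0.5f));

            // All points within Euclidean distance 5 of the query, sorted nearest-first
            List<Pair<Float, INDArray>> inRange = tree.knn(Nd4j.createFromArray(0.0f, 0.0f), 5.0f);

            // The single nearest neighbor, as a (distance, point) pair
            Pair<Double, INDArray> nearest = tree.nn(Nd4j.createFromArray(0.0f, 0.0f));

            System.out.println(tree.size() + " points, " + inRange.size() + " in range, nearest: " + nearest.getValue());
        }
    }
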
| @ -1,109 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.kmeans; |  | ||||||
| 
 |  | ||||||
| import org.deeplearning4j.clustering.algorithm.BaseClusteringAlgorithm; |  | ||||||
| import org.deeplearning4j.clustering.algorithm.Distance; |  | ||||||
| import org.deeplearning4j.clustering.strategy.ClusteringStrategy; |  | ||||||
| import org.deeplearning4j.clustering.strategy.FixedClusterCountStrategy; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| public class KMeansClustering extends BaseClusteringAlgorithm { |  | ||||||
| 
 |  | ||||||
|     private static final long serialVersionUID = 8476951388145944776L; |  | ||||||
|     private static final double VARIATION_TOLERANCE = 1e-4; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param clusteringStrategy the strategy controlling cluster count and termination |  | ||||||
|      * @param useKMeansPlusPlus whether to use k-means++ centroid initialization |  | ||||||
|      */ |  | ||||||
|     protected KMeansClustering(ClusteringStrategy clusteringStrategy, boolean useKMeansPlusPlus) { |  | ||||||
|         super(clusteringStrategy, useKMeansPlusPlus); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Set up a k-means instance |  | ||||||
|      * @param clusterCount the number of clusters |  | ||||||
|      * @param maxIterationCount the maximum number of iterations |  | ||||||
|      *                          to run k-means for |  | ||||||
|      * @param distanceFunction the distance function to use for grouping |  | ||||||
|      * @param inverse whether the measure is a similarity rather than a distance |  | ||||||
|      * @param useKMeansPlusPlus whether to use k-means++ centroid initialization |  | ||||||
|      * @return a configured KMeansClustering instance |  | ||||||
|      */ |  | ||||||
|     public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction, |  | ||||||
|                     boolean inverse, boolean useKMeansPlusPlus) { |  | ||||||
|         ClusteringStrategy clusteringStrategy = |  | ||||||
|                         FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse); |  | ||||||
|         clusteringStrategy.endWhenIterationCountEquals(maxIterationCount); |  | ||||||
|         return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Set up a k-means instance that stops once the cluster distribution stabilizes |  | ||||||
|      * @param clusterCount the number of clusters |  | ||||||
|      * @param minDistributionVariationRate the variation rate below which iteration stops |  | ||||||
|      * @param distanceFunction the distance function to use for grouping |  | ||||||
|      * @param inverse whether the measure is a similarity rather than a distance |  | ||||||
|      * @param allowEmptyClusters whether clusters may end up empty |  | ||||||
|      * @param useKMeansPlusPlus whether to use k-means++ centroid initialization |  | ||||||
|      * @return a configured KMeansClustering instance |  | ||||||
|      */ |  | ||||||
|     public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction, |  | ||||||
|                     boolean inverse, boolean allowEmptyClusters, boolean useKMeansPlusPlus) { |  | ||||||
|         ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse) |  | ||||||
|                         .endWhenDistributionVariationRateLessThan(minDistributionVariationRate); |  | ||||||
|         return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Set up a k-means instance |  | ||||||
|      * @param clusterCount the number of clusters |  | ||||||
|      * @param maxIterationCount the maximum number of iterations |  | ||||||
|      *                          to run k-means for |  | ||||||
|      * @param distanceFunction the distance function to use for grouping |  | ||||||
|      * @param useKMeansPlusPlus whether to use k-means++ centroid initialization |  | ||||||
|      * @return a configured KMeansClustering instance |  | ||||||
|      */ |  | ||||||
|     public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction, boolean useKMeansPlusPlus) { |  | ||||||
|         return setup(clusterCount, maxIterationCount, distanceFunction, false, useKMeansPlusPlus); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Set up a k-means instance that stops once the cluster distribution stabilizes |  | ||||||
|      * @param clusterCount the number of clusters |  | ||||||
|      * @param minDistributionVariationRate the variation rate below which iteration stops |  | ||||||
|      * @param distanceFunction the distance function to use for grouping |  | ||||||
|      * @param allowEmptyClusters whether clusters may end up empty |  | ||||||
|      * @param useKMeansPlusPlus whether to use k-means++ centroid initialization |  | ||||||
|      * @return a configured KMeansClustering instance |  | ||||||
|      */ |  | ||||||
|     public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction, |  | ||||||
|                     boolean allowEmptyClusters, boolean useKMeansPlusPlus) { |  | ||||||
|         ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false); |  | ||||||
|         clusteringStrategy.endWhenDistributionVariationRateLessThan(minDistributionVariationRate); |  | ||||||
|         return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public static KMeansClustering setup(int clusterCount, Distance distanceFunction, |  | ||||||
|                                          boolean allowEmptyClusters, boolean useKMeansPlusPlus) { |  | ||||||
|         ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false); |  | ||||||
|         clusteringStrategy.endWhenDistributionVariationRateLessThan(VARIATION_TOLERANCE); |  | ||||||
|         return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
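
For context, a hedged sketch of how these factories are typically driven. Distance.EUCLIDEAN, Point.toPoints, ClusterSet and the inherited applyTo method come from the surrounding clustering package rather than from this diff, so treat those names as assumptions.

    import org.deeplearning4j.clustering.algorithm.Distance;
    import org.deeplearning4j.clustering.cluster.ClusterSet;
    import org.deeplearning4j.clustering.cluster.Point;
    import org.nd4j.linalg.factory.Nd4j;

    import java.util.List;

    public class KMeansExample {
        public static void main(String[] args) {
            // 500 random points in 10 dimensions, wrapped in the package's Point type (assumed helper)
            List<Point> points = Point.toPoints(Nd4j.rand(500, 10));

            // 5 clusters, at most 100 iterations, plain (non k-means++) initialization
            KMeansClustering kmeans = KMeansClustering.setup(5, 100, Distance.EUCLIDEAN, false);

            // applyTo is assumed to be inherited from BaseClusteringAlgorithm
            ClusterSet clusterSet = kmeans.applyTo(points);
            System.out.println("clusters: " + clusterSet.getClusters().size());
        }
    }
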
| @ -1,88 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.lsh; |  | ||||||
| 
 |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| 
 |  | ||||||
| public interface LSH { |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Returns the name of the distance measure associated with the LSH family of this implementation. |  | ||||||
|      * Beware, hashing families and their amplification constructs are distance-specific. |  | ||||||
|      */ |  | ||||||
|      String getDistanceMeasure(); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Returns the length of the hashes compared within one bucket, corresponding to an AND construction. |  | ||||||
|      * |  | ||||||
|      * denoting hashLength by h, |  | ||||||
|      * amplifies a (d1, d2, p1, p2) hash family into a |  | ||||||
|      *                   (d1, d2, p1^h, p2^h)-sensitive one (match probability is decreasing with h) |  | ||||||
|      * |  | ||||||
|      * @return the length of the hash in the AND construction used by this index |  | ||||||
|      */ |  | ||||||
|      int getHashLength(); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Returns the number of hash tables, corresponding to an OR construction. |  | ||||||
|      * denoting numTables by n, |  | ||||||
|      * amplifies a (d1, d2, p1, p2) hash family into a |  | ||||||
|      *                   (d1, d2, (1-p1^n), (1-p2^n))-sensitive one (match probability is increasing with n) |  | ||||||
|      * |  | ||||||
|      * @return the # of hash tables in the OR construction used by this index |  | ||||||
|      */ |  | ||||||
|      int getNumTables(); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @return The dimension of the index vectors and queries |  | ||||||
|      */ |  | ||||||
|      int getInDimension(); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Populates the index with data vectors. |  | ||||||
|      * @param data the vectors to index |  | ||||||
|      */ |  | ||||||
|      void makeIndex(INDArray data); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Returns the set of all vectors that could approximately be considered neighbors of the query, |  | ||||||
|      * without selection on the basis of distance or number of neighbors. |  | ||||||
|      * @param query a vector to find neighbors for |  | ||||||
|      * @return its approximate neighbors, unfiltered |  | ||||||
|      */ |  | ||||||
|      INDArray bucket(INDArray query); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Returns the approximate neighbors within a distance bound. |  | ||||||
|      * @param query a vector to find neighbors for |  | ||||||
|      * @param maxRange the maximum distance between results and the query |  | ||||||
|      * @return approximate neighbors within the distance bounds |  | ||||||
|      */ |  | ||||||
|      INDArray search(INDArray query, double maxRange); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Returns the approximate neighbors within a k-closest bound |  | ||||||
|      * @param query a vector to find neighbors for |  | ||||||
|      * @param k the maximum number of closest neighbors to return |  | ||||||
|      * @return at most k neighbors of the query, ordered by increasing distance |  | ||||||
|      */ |  | ||||||
|      INDArray search(INDArray query, int k); |  | ||||||
| } |  | ||||||
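
The AND/OR amplification described in the javadoc above is easy to sanity-check numerically. The sketch below is illustrative plain Java with made-up probabilities, not code from this interface:

    public class AmplificationDemo {
        public static void main(String[] args) {
            double p1 = 0.9, p2 = 0.5; // example collision probabilities at the near and far distances d1 < d2
            int h = 4;                 // hashLength: AND construction over h hash bits
            int n = 8;                 // numTables: OR construction over n tables

            double p1And = Math.pow(p1, h);             // 0.9^4 = 0.6561
            double p2And = Math.pow(p2, h);             // 0.5^4 = 0.0625
            double p1Both = 1 - Math.pow(1 - p1And, n); // ~0.9998: near points almost always collide somewhere
            double p2Both = 1 - Math.pow(1 - p2And, n); // ~0.4033: far points collide far less often

            System.out.printf("AND: %.4f vs %.4f; AND+OR: %.4f vs %.4f%n", p1And, p2And, p1Both, p2Both);
        }
    }

The AND step widens the gap between near and far collision probabilities; the OR step then recovers recall for the near points.
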
| @ -1,227 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.lsh; |  | ||||||
| 
 |  | ||||||
| import lombok.Getter; |  | ||||||
| import lombok.val; |  | ||||||
| import org.nd4j.common.base.Preconditions; |  | ||||||
| import org.nd4j.linalg.api.buffer.DataType; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo; |  | ||||||
| import org.nd4j.linalg.api.ops.impl.transforms.same.Sign; |  | ||||||
| import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution; |  | ||||||
| import org.nd4j.linalg.api.rng.Random; |  | ||||||
| import org.nd4j.linalg.exception.ND4JIllegalStateException; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| import org.nd4j.linalg.indexing.BooleanIndexing; |  | ||||||
| import org.nd4j.linalg.indexing.conditions.Conditions; |  | ||||||
| import org.nd4j.linalg.ops.transforms.Transforms; |  | ||||||
| 
 |  | ||||||
| import java.util.Arrays; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| public class RandomProjectionLSH implements LSH { |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public String getDistanceMeasure(){ |  | ||||||
|         return "cosinedistance"; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Getter private int hashLength; |  | ||||||
| 
 |  | ||||||
|     @Getter private int numTables; |  | ||||||
| 
 |  | ||||||
|     @Getter private int inDimension; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Getter private double radius; |  | ||||||
| 
 |  | ||||||
|     INDArray randomProjection; |  | ||||||
| 
 |  | ||||||
|     INDArray index; |  | ||||||
| 
 |  | ||||||
|     INDArray indexData; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     private INDArray gaussianRandomMatrix(int[] shape, Random rng){ |  | ||||||
|         INDArray res = Nd4j.create(shape); |  | ||||||
| 
 |  | ||||||
|         GaussianDistribution op1 = new GaussianDistribution(res, 0.0, 1.0 / Math.sqrt(shape[0])); |  | ||||||
| 
 |  | ||||||
|         Nd4j.getExecutioner().exec(op1, rng); |  | ||||||
|         return res; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public RandomProjectionLSH(int hashLength, int numTables, int inDimension, double radius){ |  | ||||||
|         this(hashLength, numTables, inDimension, radius, Nd4j.getRandom()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Creates a locality-sensitive hashing index for the cosine distance, |  | ||||||
|      * a (d1, d2, (180 − d1)/180,(180 − d2)/180)-sensitive hash family before amplification |  | ||||||
|      * |  | ||||||
|      * @param hashLength the length of the compared hash in an AND construction |  | ||||||
|      * @param numTables the entropy-equivalent of a number of hash tables in an OR construction, implemented here with the multiple |  | ||||||
|      *                  probes of Panigrahy (op. cit.) |  | ||||||
|      * @param inDimension the dimensionality of the points being indexed |  | ||||||
|      * @param radius the radius around the query within which probes are generated, used instead of multiple physical hash tables in an OR construction |  | ||||||
|      * @param rng a Random object to draw samples from |  | ||||||
|      */ |  | ||||||
|     public RandomProjectionLSH(int hashLength, int numTables, int inDimension, double radius, Random rng){ |  | ||||||
|         this.hashLength = hashLength; |  | ||||||
|         this.numTables = numTables; |  | ||||||
|         this.inDimension = inDimension; |  | ||||||
|         this.radius = radius; |  | ||||||
|         randomProjection = gaussianRandomMatrix(new int[]{inDimension, hashLength}, rng); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * This picks uniformly distributed random points on the surface of a sphere using the method of: |  | ||||||
|      * |  | ||||||
|      * An efficient method for generating uniformly distributed points on the surface of an n-dimensional sphere |  | ||||||
|      * JS Hicks, RF Wheeling - Communications of the ACM, 1959 |  | ||||||
|      * @param data a query to generate multiple probes for |  | ||||||
|      * @return {@code numTables} random probe points centered on the query |  | ||||||
|      */ |  | ||||||
|     public INDArray entropy(INDArray data){ |  | ||||||
| 
 |  | ||||||
|         INDArray data2 = |  | ||||||
|                     Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.create(numTables, inDimension), radius)); |  | ||||||
| 
 |  | ||||||
|         INDArray norms = Nd4j.norm2(data2.dup(), -1); |  | ||||||
| 
 |  | ||||||
|         Preconditions.checkState(norms.rank() == 1 && norms.size(0) == numTables, "Expected norm2 to have shape [%s], is %ndShape", norms.size(0), norms); |  | ||||||
| 
 |  | ||||||
|         data2.diviColumnVector(norms); |  | ||||||
|         data2.addiRowVector(data); |  | ||||||
|         return data2; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Returns hash values for a particular query |  | ||||||
|      * @param data a query vector |  | ||||||
|      * @return its hashed value |  | ||||||
|      */ |  | ||||||
|     public INDArray hash(INDArray data) { |  | ||||||
|         if (data.shape()[1] != inDimension){ |  | ||||||
|             throw new ND4JIllegalStateException( |  | ||||||
|                     String.format("Invalid shape: Requested INDArray shape %s, this table expects dimension %d", |  | ||||||
|                             Arrays.toString(data.shape()), inDimension)); |  | ||||||
|         } |  | ||||||
|         INDArray projected = data.mmul(randomProjection); |  | ||||||
|         INDArray res = Nd4j.getExecutioner().exec(new Sign(projected)); |  | ||||||
|         return res; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Populates the index. Beware, not incremental, any further call replaces the index instead of adding to it. |  | ||||||
|      * @param data the vectors to index |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public void makeIndex(INDArray data) { |  | ||||||
|         index = hash(data); |  | ||||||
|         indexData = data; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // data elements in the same bucket as the query, without entropy |  | ||||||
|     INDArray rawBucketOf(INDArray query){ |  | ||||||
|         INDArray pattern = hash(query); |  | ||||||
| 
 |  | ||||||
|         INDArray res = Nd4j.zeros(DataType.BOOL, index.shape()); |  | ||||||
|         Nd4j.getExecutioner().exec(new BroadcastEqualTo(index, pattern, res, -1)); |  | ||||||
|         return res.castTo(Nd4j.defaultFloatingPointType()).min(-1); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public INDArray bucket(INDArray query) { |  | ||||||
|         INDArray queryRes = rawBucketOf(query); |  | ||||||
| 
 |  | ||||||
|         if(numTables > 1) { |  | ||||||
|             INDArray entropyQueries = entropy(query); |  | ||||||
| 
 |  | ||||||
|             // loop, addi + conditional replace -> poor man's OR function |  | ||||||
|             for (int i = 0; i < numTables; i++) { |  | ||||||
|                 INDArray row = entropyQueries.getRow(i, true); |  | ||||||
|                 queryRes.addi(rawBucketOf(row)); |  | ||||||
|             } |  | ||||||
|             BooleanIndexing.replaceWhere(queryRes, 1.0, Conditions.greaterThan(0.0)); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return queryRes; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     // data elements in the same entropy bucket as the query |  | ||||||
|     INDArray bucketData(INDArray query){ |  | ||||||
|         INDArray mask = bucket(query); |  | ||||||
|         int nRes = mask.sum(0).getInt(0); |  | ||||||
|         INDArray res = Nd4j.create(new int[] {nRes, inDimension}); |  | ||||||
|         int j = 0; |  | ||||||
|         for (int i = 0; i < nRes; i++){ |  | ||||||
|             while (mask.getInt(j) == 0 && j < mask.length() - 1) { |  | ||||||
|                 j += 1; |  | ||||||
|             } |  | ||||||
|             if (mask.getInt(j) == 1) res.putRow(i, indexData.getRow(j)); |  | ||||||
|             j += 1; |  | ||||||
|         } |  | ||||||
|         return res; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public INDArray search(INDArray query, double maxRange) { |  | ||||||
|         if (maxRange < 0) |  | ||||||
|             throw new IllegalArgumentException("ANN search should have a positive maximum search radius"); |  | ||||||
| 
 |  | ||||||
|         INDArray bucketData = bucketData(query); |  | ||||||
|         INDArray distances = Transforms.allCosineDistances(bucketData, query, -1); |  | ||||||
|         INDArray[] idxs = Nd4j.sortWithIndices(distances, -1, true); |  | ||||||
| 
 |  | ||||||
|         INDArray shuffleIndexes = idxs[0]; |  | ||||||
|         INDArray sortedDistances = idxs[1]; |  | ||||||
|         int accepted = 0; |  | ||||||
|         // Compare distances as doubles; getInt would truncate fractional cosine distances to 0 |  | ||||||
|         while (accepted < sortedDistances.length() && sortedDistances.getDouble(accepted) <= maxRange) accepted += 1; |  | ||||||
| 
 |  | ||||||
|         INDArray res = Nd4j.create(new int[] {accepted, inDimension}); |  | ||||||
|         for(int i = 0; i < accepted; i++){ |  | ||||||
|             res.putRow(i, bucketData.getRow(shuffleIndexes.getInt(i))); |  | ||||||
|         } |  | ||||||
|         return res; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public INDArray search(INDArray query, int k) { |  | ||||||
|         if (k < 1) |  | ||||||
|             throw new IllegalArgumentException("An ANN search for k neighbors should at least seek one neighbor"); |  | ||||||
| 
 |  | ||||||
|         INDArray bucketData = bucketData(query); |  | ||||||
|         INDArray distances = Transforms.allCosineDistances(bucketData, query, -1); |  | ||||||
|         INDArray[] idxs = Nd4j.sortWithIndices(distances, -1, true); |  | ||||||
| 
 |  | ||||||
|         INDArray shuffleIndexes = idxs[0]; |  | ||||||
|         INDArray sortedDistances = idxs[1]; |  | ||||||
|         val accepted = Math.min(k, sortedDistances.shape()[1]); |  | ||||||
| 
 |  | ||||||
|         INDArray res = Nd4j.create(accepted, inDimension); |  | ||||||
|         for(int i = 0; i < accepted; i++){ |  | ||||||
|             res.putRow(i, bucketData.getRow(shuffleIndexes.getInt(i))); |  | ||||||
|         } |  | ||||||
|         return res; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
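
A minimal usage sketch of the class above, using only the constructor and the LSH methods it defines; the dimensions and parameter values are illustrative.

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class RandomProjectionLSHExample {
        public static void main(String[] args) {
            int inDimension = 100;

            // 16-bit hashes (AND construction), 4 entropy probes (OR construction), probe radius 0.1
            RandomProjectionLSH lsh = new RandomProjectionLSH(16, 4, inDimension, 0.1);

            // Index 10,000 random vectors; note makeIndex replaces, rather than extends, any prior index
            lsh.makeIndex(Nd4j.rand(10_000, inDimension));

            INDArray query = Nd4j.rand(1, inDimension);
            INDArray tenNearest = lsh.search(query, 10); // k-nearest variant
            INDArray inRange = lsh.search(query, 0.2);   // maximum cosine distance variant

            System.out.println(tenNearest.rows() + " / " + inRange.rows() + " rows returned");
        }
    }
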
| @ -1,38 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.optimisation; |  | ||||||
| 
 |  | ||||||
| import lombok.AccessLevel; |  | ||||||
| import lombok.AllArgsConstructor; |  | ||||||
| import lombok.Data; |  | ||||||
| import lombok.NoArgsConstructor; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| @NoArgsConstructor(access = AccessLevel.PROTECTED) |  | ||||||
| @AllArgsConstructor |  | ||||||
| public class ClusteringOptimization implements Serializable { |  | ||||||
| 
 |  | ||||||
|     private ClusteringOptimizationType type; |  | ||||||
|     private double value; |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,28 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.optimisation; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * The objective a clustering run can be asked to optimize for. |  | ||||||
|  */ |  | ||||||
| public enum ClusteringOptimizationType { |  | ||||||
|     MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE, MINIMIZE_MAXIMUM_POINT_TO_CENTER_DISTANCE, MINIMIZE_AVERAGE_POINT_TO_POINT_DISTANCE, MINIMIZE_MAXIMUM_POINT_TO_POINT_DISTANCE, MINIMIZE_PER_CLUSTER_POINT_COUNT |  | ||||||
| } |  | ||||||
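
A small hedged sketch tying the two types above together; the constructor and getters are the ones Lombok generates from @AllArgsConstructor and @Data.

    public class OptimizationExample {
        public static void main(String[] args) {
            // Ask a clustering run to drive the mean point-to-centroid distance below 0.5
            ClusteringOptimization opt = new ClusteringOptimization(
                    ClusteringOptimizationType.MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE, 0.5);
            System.out.println(opt.getType() + " -> " + opt.getValue());
        }
    }
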
| @ -1,115 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.quadtree; |  | ||||||
| 
 |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| 
 |  | ||||||
| public class Cell implements Serializable { |  | ||||||
|     private double x, y, hw, hh; |  | ||||||
| 
 |  | ||||||
|     public Cell(double x, double y, double hw, double hh) { |  | ||||||
|         this.x = x; |  | ||||||
|         this.y = y; |  | ||||||
|         this.hw = hw; |  | ||||||
|         this.hh = hh; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Whether the given point is contained |  | ||||||
|      * within this cell |  | ||||||
|      * @param point the point to check |  | ||||||
|      * @return true if the point is contained, false otherwise |  | ||||||
|      */ |  | ||||||
|     public boolean containsPoint(INDArray point) { |  | ||||||
|         double first = point.getDouble(0), second = point.getDouble(1); |  | ||||||
|         return x - hw <= first && x + hw >= first && y - hh <= second && y + hh >= second; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public boolean equals(Object o) { |  | ||||||
|         if (this == o) |  | ||||||
|             return true; |  | ||||||
|         if (!(o instanceof Cell)) |  | ||||||
|             return false; |  | ||||||
| 
 |  | ||||||
|         Cell cell = (Cell) o; |  | ||||||
| 
 |  | ||||||
|         if (Double.compare(cell.hh, hh) != 0) |  | ||||||
|             return false; |  | ||||||
|         if (Double.compare(cell.hw, hw) != 0) |  | ||||||
|             return false; |  | ||||||
|         if (Double.compare(cell.x, x) != 0) |  | ||||||
|             return false; |  | ||||||
|         return Double.compare(cell.y, y) == 0; |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public int hashCode() { |  | ||||||
|         int result; |  | ||||||
|         long temp; |  | ||||||
|         temp = Double.doubleToLongBits(x); |  | ||||||
|         result = (int) (temp ^ (temp >>> 32)); |  | ||||||
|         temp = Double.doubleToLongBits(y); |  | ||||||
|         result = 31 * result + (int) (temp ^ (temp >>> 32)); |  | ||||||
|         temp = Double.doubleToLongBits(hw); |  | ||||||
|         result = 31 * result + (int) (temp ^ (temp >>> 32)); |  | ||||||
|         temp = Double.doubleToLongBits(hh); |  | ||||||
|         result = 31 * result + (int) (temp ^ (temp >>> 32)); |  | ||||||
|         return result; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public double getX() { |  | ||||||
|         return x; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setX(double x) { |  | ||||||
|         this.x = x; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public double getY() { |  | ||||||
|         return y; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setY(double y) { |  | ||||||
|         this.y = y; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public double getHw() { |  | ||||||
|         return hw; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setHw(double hw) { |  | ||||||
|         this.hw = hw; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public double getHh() { |  | ||||||
|         return hh; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setHh(double hh) { |  | ||||||
|         this.hh = hh; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
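
Since the (x, y, hw, hh) convention above is center-plus-half-extents rather than corner-plus-size, a quick hedged sketch of containsPoint:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class CellExample {
        public static void main(String[] args) {
            // Centered at the origin with half-width and half-height 1: the square [-1, 1] x [-1, 1]
            Cell cell = new Cell(0.0, 0.0, 1.0, 1.0);

            INDArray inside = Nd4j.create(new double[] {0.5, -0.25});
            INDArray outside = Nd4j.create(new double[] {1.5, 0.0});

            System.out.println(cell.containsPoint(inside));  // true
            System.out.println(cell.containsPoint(outside)); // false: 1.5 > x + hw
        }
    }
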
| @ -1,383 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.quadtree; |  | ||||||
| 
 |  | ||||||
| import org.nd4j.shade.guava.util.concurrent.AtomicDouble; |  | ||||||
| import org.apache.commons.math3.util.FastMath; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| 
 |  | ||||||
| import static java.lang.Math.max; |  | ||||||
| 
 |  | ||||||
| public class QuadTree implements Serializable { |  | ||||||
|     private QuadTree parent, northWest, northEast, southWest, southEast; |  | ||||||
|     private boolean isLeaf = true; |  | ||||||
|     private int size, cumSize; |  | ||||||
|     private Cell boundary; |  | ||||||
|     static final int QT_NO_DIMS = 2; |  | ||||||
|     static final int QT_NODE_CAPACITY = 1; |  | ||||||
|     private INDArray buf = Nd4j.create(QT_NO_DIMS); |  | ||||||
|     private INDArray data, centerOfMass = Nd4j.create(QT_NO_DIMS); |  | ||||||
|     private int[] index = new int[QT_NODE_CAPACITY]; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Build a quad tree over a matrix of 2d points |  | ||||||
|      * @param data the data matrix, one row per point, with two columns (x and y) |  | ||||||
|      */ |  | ||||||
|     public QuadTree(INDArray data) { |  | ||||||
|         INDArray meanY = data.mean(0); |  | ||||||
|         INDArray minY = data.min(0); |  | ||||||
|         INDArray maxY = data.max(0); |  | ||||||
|         init(data, meanY.getDouble(0), meanY.getDouble(1), |  | ||||||
|                         max(maxY.getDouble(0) - meanY.getDouble(0), meanY.getDouble(0) - minY.getDouble(0)) |  | ||||||
|                                         + Nd4j.EPS_THRESHOLD, |  | ||||||
|                         max(maxY.getDouble(1) - meanY.getDouble(1), meanY.getDouble(1) - minY.getDouble(1)) |  | ||||||
|                                         + Nd4j.EPS_THRESHOLD); |  | ||||||
|         fill(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public QuadTree(QuadTree parent, INDArray data, Cell boundary) { |  | ||||||
|         this.parent = parent; |  | ||||||
|         this.boundary = boundary; |  | ||||||
|         this.data = data; |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public QuadTree(Cell boundary) { |  | ||||||
|         this.boundary = boundary; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private void init(INDArray data, double x, double y, double hw, double hh) { |  | ||||||
|         boundary = new Cell(x, y, hw, hh); |  | ||||||
|         this.data = data; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private void fill() { |  | ||||||
|         for (int i = 0; i < data.rows(); i++) |  | ||||||
|             insert(i); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Returns the child quadrant that should hold the given element |  | ||||||
|      * |  | ||||||
|      * @param coordinates the (x, y) coordinates of the element |  | ||||||
|      * @return the child quad tree the coordinates fall into |  | ||||||
|      */ |  | ||||||
|     protected QuadTree findIndex(INDArray coordinates) { |  | ||||||
| 
 |  | ||||||
|         // Compute the sector for the coordinates |  | ||||||
|         boolean left = (coordinates.getDouble(0) <= (boundary.getX() + boundary.getHw() / 2)); |  | ||||||
|         boolean top = (coordinates.getDouble(1) <= (boundary.getY() + boundary.getHh() / 2)); |  | ||||||
| 
 |  | ||||||
|         // top left |  | ||||||
|         QuadTree index = getNorthWest(); |  | ||||||
|         if (left) { |  | ||||||
|             // left side |  | ||||||
|             if (!top) { |  | ||||||
|                 // bottom left |  | ||||||
|                 index = getSouthWest(); |  | ||||||
|             } |  | ||||||
|         } else { |  | ||||||
|             // right side |  | ||||||
|             if (top) { |  | ||||||
|                 // top right |  | ||||||
|                 index = getNorthEast(); |  | ||||||
|             } else { |  | ||||||
|                 // bottom right |  | ||||||
|                 index = getSouthEast(); |  | ||||||
| 
 |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return index; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Insert an index of the data into the tree |  | ||||||
|      * @param newIndex the index to insert into the tree |  | ||||||
|      * @return whether the index was inserted or not |  | ||||||
|      */ |  | ||||||
|     public boolean insert(int newIndex) { |  | ||||||
|         // Ignore objects which do not belong in this quad tree |  | ||||||
|         INDArray point = data.slice(newIndex); |  | ||||||
|         if (!boundary.containsPoint(point)) |  | ||||||
|             return false; |  | ||||||
| 
 |  | ||||||
|         cumSize++; |  | ||||||
|         double mult1 = (double) (cumSize - 1) / (double) cumSize; |  | ||||||
|         double mult2 = 1.0 / (double) cumSize; |  | ||||||
| 
 |  | ||||||
|         centerOfMass.muli(mult1); |  | ||||||
|         centerOfMass.addi(point.mul(mult2)); |  | ||||||
| 
 |  | ||||||
|         // If there is space in this quad tree and it is a leaf, add the object here |  | ||||||
|         if (isLeaf() && size < QT_NODE_CAPACITY) { |  | ||||||
|             index[size] = newIndex; |  | ||||||
|             size++; |  | ||||||
|             return true; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         //duplicate point |  | ||||||
|         if (size > 0) { |  | ||||||
|             for (int i = 0; i < size; i++) { |  | ||||||
|                 INDArray compPoint = data.slice(index[i]); |  | ||||||
|                 if (point.getDouble(0) == compPoint.getDouble(0) && point.getDouble(1) == compPoint.getDouble(1)) |  | ||||||
|                     return true; |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         // If this Node has already been subdivided just add the elements to the |  | ||||||
|         // appropriate cell |  | ||||||
|         if (!isLeaf()) { |  | ||||||
|             // Offer the point to each child in turn rather than trusting findIndex: its midpoint |  | ||||||
|             // test can pick a quadrant that does not contain the point, silently dropping it |  | ||||||
|             return insertIntoOneOf(newIndex); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         if (isLeaf()) |  | ||||||
|             subDivide(); |  | ||||||
| 
 |  | ||||||
|         boolean ret = insertIntoOneOf(newIndex); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private boolean insertIntoOneOf(int index) { |  | ||||||
|         boolean success = false; |  | ||||||
|         success = northWest.insert(index); |  | ||||||
|         if (!success) |  | ||||||
|             success = northEast.insert(index); |  | ||||||
|         if (!success) |  | ||||||
|             success = southWest.insert(index); |  | ||||||
|         if (!success) |  | ||||||
|             success = southEast.insert(index); |  | ||||||
|         return success; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Returns whether the tree is consistent or not |  | ||||||
|      * @return whether the tree is consistent or not |  | ||||||
|      */ |  | ||||||
|     public boolean isCorrect() { |  | ||||||
| 
 |  | ||||||
|         for (int n = 0; n < size; n++) { |  | ||||||
|             INDArray point = data.slice(index[n]); |  | ||||||
|             if (!boundary.containsPoint(point)) |  | ||||||
|                 return false; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return isLeaf() || northWest.isCorrect() && northEast.isCorrect() && southWest.isCorrect() |  | ||||||
|                         && southEast.isCorrect(); |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      *  Create four children |  | ||||||
|      *  which fully divide this cell |  | ||||||
|      *  into four quads of equal area |  | ||||||
|      */ |  | ||||||
|     public void subDivide() { |  | ||||||
|         northWest = new QuadTree(this, data, new Cell(boundary.getX() - .5 * boundary.getHw(), |  | ||||||
|                         boundary.getY() - .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh())); |  | ||||||
|         northEast = new QuadTree(this, data, new Cell(boundary.getX() + .5 * boundary.getHw(), |  | ||||||
|                         boundary.getY() - .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh())); |  | ||||||
|         southWest = new QuadTree(this, data, new Cell(boundary.getX() - .5 * boundary.getHw(), |  | ||||||
|                         boundary.getY() + .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh())); |  | ||||||
|         southEast = new QuadTree(this, data, new Cell(boundary.getX() + .5 * boundary.getHw(), |  | ||||||
|                         boundary.getY() + .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh())); |  | ||||||
| 
 |  | ||||||
|         isLeaf = false; // mark this node as internal, so later inserts descend instead of re-subdividing and orphaning these children |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Compute non-edge forces using Barnes-Hut |  | ||||||
|      * @param pointIndex the index of the point the forces act on |  | ||||||
|      * @param theta the accuracy/speed trade-off: larger values summarize more distant cells |  | ||||||
|      * @param negativeForce the accumulator for the negative (repulsive) force |  | ||||||
|      * @param sumQ the accumulator for the normalization term |  | ||||||
|      */ |  | ||||||
|     public void computeNonEdgeForces(int pointIndex, double theta, INDArray negativeForce, AtomicDouble sumQ) { |  | ||||||
|         // Make sure that we spend no time on empty nodes or self-interactions |  | ||||||
|         if (cumSize == 0 || (isLeaf() && size == 1 && index[0] == pointIndex)) |  | ||||||
|             return; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         // Compute distance between point and center-of-mass |  | ||||||
|         buf.assign(data.slice(pointIndex)).subi(centerOfMass); |  | ||||||
| 
 |  | ||||||
|         double D = Nd4j.getBlasWrapper().dot(buf, buf); |  | ||||||
| 
 |  | ||||||
|         // Check whether we can use this node as a "summary" |  | ||||||
|         if (isLeaf || FastMath.max(boundary.getHh(), boundary.getHw()) / FastMath.sqrt(D) < theta) { |  | ||||||
| 
 |  | ||||||
|             // Compute and add t-SNE force between point and current node |  | ||||||
|             double Q = 1.0 / (1.0 + D); |  | ||||||
|             double mult = cumSize * Q; |  | ||||||
|             sumQ.addAndGet(mult); |  | ||||||
|             mult *= Q; |  | ||||||
|             negativeForce.addi(buf.mul(mult)); |  | ||||||
| 
 |  | ||||||
|         } else { |  | ||||||
| 
 |  | ||||||
|             // Recursively apply Barnes-Hut to children |  | ||||||
|             northWest.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ); |  | ||||||
|             northEast.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ); |  | ||||||
|             southWest.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ); |  | ||||||
|             southEast.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Compute the edge (attractive) forces for a sparse similarity matrix in CSR form |  | ||||||
|      * @param rowP a vector of row offsets: row n's entries live at positions rowP[n]..rowP[n + 1] |  | ||||||
|      * @param colP the column indices of the non-zero entries |  | ||||||
|      * @param valP the values of the non-zero entries |  | ||||||
|      * @param N the number of points |  | ||||||
|      * @param posF the accumulator for the positive forces, one row per point |  | ||||||
|      */ |  | ||||||
|     public void computeEdgeForces(INDArray rowP, INDArray colP, INDArray valP, int N, INDArray posF) { |  | ||||||
|         if (!rowP.isVector()) |  | ||||||
|             throw new IllegalArgumentException("RowP must be a vector"); |  | ||||||
| 
 |  | ||||||
|         // Loop over all edges in the graph |  | ||||||
|         double D; |  | ||||||
|         for (int n = 0; n < N; n++) { |  | ||||||
|             for (int i = rowP.getInt(n); i < rowP.getInt(n + 1); i++) { |  | ||||||
| 
 |  | ||||||
|                 // Compute pairwise distance and Q-value |  | ||||||
|                 buf.assign(data.slice(n)).subi(data.slice(colP.getInt(i))); |  | ||||||
| 
 |  | ||||||
|                 D = Nd4j.getBlasWrapper().dot(buf, buf); |  | ||||||
|                 D = valP.getDouble(i) / D; |  | ||||||
| 
 |  | ||||||
|                 // Sum positive force |  | ||||||
|                 posF.slice(n).addi(buf.mul(D)); |  | ||||||
| 
 |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * The depth of the node |  | ||||||
|      * @return the depth of the node |  | ||||||
|      */ |  | ||||||
|     public int depth() { |  | ||||||
|         if (isLeaf()) |  | ||||||
|             return 1; |  | ||||||
|         return 1 + max(max(northWest.depth(), northEast.depth()), max(southWest.depth(), southEast.depth())); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public INDArray getCenterOfMass() { |  | ||||||
|         return centerOfMass; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setCenterOfMass(INDArray centerOfMass) { |  | ||||||
|         this.centerOfMass = centerOfMass; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public QuadTree getParent() { |  | ||||||
|         return parent; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setParent(QuadTree parent) { |  | ||||||
|         this.parent = parent; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public QuadTree getNorthWest() { |  | ||||||
|         return northWest; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setNorthWest(QuadTree northWest) { |  | ||||||
|         this.northWest = northWest; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public QuadTree getNorthEast() { |  | ||||||
|         return northEast; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setNorthEast(QuadTree northEast) { |  | ||||||
|         this.northEast = northEast; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public QuadTree getSouthWest() { |  | ||||||
|         return southWest; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setSouthWest(QuadTree southWest) { |  | ||||||
|         this.southWest = southWest; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public QuadTree getSouthEast() { |  | ||||||
|         return southEast; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setSouthEast(QuadTree southEast) { |  | ||||||
|         this.southEast = southEast; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public boolean isLeaf() { |  | ||||||
|         return isLeaf; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setLeaf(boolean isLeaf) { |  | ||||||
|         this.isLeaf = isLeaf; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public int getSize() { |  | ||||||
|         return size; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setSize(int size) { |  | ||||||
|         this.size = size; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public int getCumSize() { |  | ||||||
|         return cumSize; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setCumSize(int cumSize) { |  | ||||||
|         this.cumSize = cumSize; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public Cell getBoundary() { |  | ||||||
|         return boundary; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setBoundary(Cell boundary) { |  | ||||||
|         this.boundary = boundary; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
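| // Hedged usage sketch for the accessors above (the QuadTree(INDArray) |  | ||||||
| // constructor is assumed from the full class, which this excerpt omits): |  | ||||||
| //   QuadTree tree = new QuadTree(points);  // points: an [n, 2] INDArray |  | ||||||
| //   int d = tree.depth();                  // 1 when the root is a leaf |  | ||||||
| //   boolean leaf = tree.isLeaf(); |  | ||||||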
| @ -1,104 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.randomprojection; |  | ||||||
| 
 |  | ||||||
| import lombok.Data; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.common.primitives.Pair; |  | ||||||
| 
 |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * A forest of random projection trees for approximate nearest-neighbor search. |  | ||||||
|  */ |  | ||||||
| @Data |  | ||||||
| public class RPForest { |  | ||||||
| 
 |  | ||||||
|     private int numTrees; |  | ||||||
|     private List<RPTree> trees; |  | ||||||
|     private INDArray data; |  | ||||||
|     private int maxSize = 1000; |  | ||||||
|     private String similarityFunction; |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Create the rp forest with the specified number of trees |  | ||||||
|      * @param numTrees the number of trees in the forest |  | ||||||
|      * @param maxSize the max size of each tree |  | ||||||
|      * @param similarityFunction the similarity function to use (e.g. "euclidean") |  | ||||||
|      */ |  | ||||||
|     public RPForest(int numTrees,int maxSize,String similarityFunction) { |  | ||||||
|         this.numTrees = numTrees; |  | ||||||
|         this.maxSize = maxSize; |  | ||||||
|         this.similarityFunction = similarityFunction; |  | ||||||
|         trees = new ArrayList<>(numTrees); |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Build the trees from the given dataset |  | ||||||
|      * @param x the input dataset (should be a 2d matrix) |  | ||||||
|      */ |  | ||||||
|     public void fit(INDArray x) { |  | ||||||
|         this.data = x; |  | ||||||
|         for(int i = 0; i < numTrees; i++) { |  | ||||||
|             RPTree tree = new RPTree(data.columns(),maxSize,similarityFunction); |  | ||||||
|             tree.buildTree(x); |  | ||||||
|             trees.add(tree); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get all candidates relative to a specific datapoint. |  | ||||||
|      * @param input the query vector |  | ||||||
|      * @return the candidate indices found across all trees |  | ||||||
|      */ |  | ||||||
|     public INDArray getAllCandidates(INDArray input) { |  | ||||||
|         return RPUtils.getAllCandidates(input,trees,similarityFunction); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Query for the n nearest neighbors of the given item. |  | ||||||
|      * @param toQuery the query item |  | ||||||
|      * @param n the number of nearest neighbors for the given data point |  | ||||||
|      * @return the indices for the nearest neighbors |  | ||||||
|      */ |  | ||||||
|     public INDArray queryAll(INDArray toQuery,int n) { |  | ||||||
|         return RPUtils.queryAll(toQuery,data,trees,n,similarityFunction); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Query all with the distances |  | ||||||
|      * sorted by index |  | ||||||
|      * @param query the query vector |  | ||||||
|      * @param numResults the number of results to return |  | ||||||
|      * @return a list of samples |  | ||||||
|      */ |  | ||||||
|     public List<Pair<Double, Integer>> queryWithDistances(INDArray query, int numResults) { |  | ||||||
|         return RPUtils.queryAllWithDistances(query,this.data, trees,numResults,similarityFunction); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
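| // Minimal usage sketch for RPForest (shapes are illustrative; only the |  | ||||||
| // constructor, fit and queryAll defined above are used): |  | ||||||
| //   INDArray data = Nd4j.randn(1000, 32);             // 1000 vectors, dim 32 |  | ||||||
| //   RPForest forest = new RPForest(10, 100, "euclidean"); |  | ||||||
| //   forest.fit(data); |  | ||||||
| //   INDArray nn = forest.queryAll(data.getRow(0), 5); // 5 neighbor indices |  | ||||||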
| @ -1,57 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.randomprojection; |  | ||||||
| 
 |  | ||||||
| import lombok.Data; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| public class RPHyperPlanes { |  | ||||||
|     private int dim; |  | ||||||
|     private INDArray wholeHyperPlane; |  | ||||||
| 
 |  | ||||||
|     public RPHyperPlanes(int dim) { |  | ||||||
|         this.dim = dim; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public INDArray getHyperPlaneAt(int depth) { |  | ||||||
|         if(wholeHyperPlane.isVector()) |  | ||||||
|             return wholeHyperPlane; |  | ||||||
|         return wholeHyperPlane.slice(depth); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Append a new random hyperplane: a Gaussian row vector scaled by its max-norm. |  | ||||||
|      */ |  | ||||||
|     public void addRandomHyperPlane() { |  | ||||||
|         INDArray newPlane = Nd4j.randn(new int[] {1,dim}); |  | ||||||
|         newPlane.divi(newPlane.normmaxNumber()); |  | ||||||
|         if(wholeHyperPlane == null) |  | ||||||
|             wholeHyperPlane = newPlane; |  | ||||||
|         else { |  | ||||||
|             wholeHyperPlane = Nd4j.concat(0,wholeHyperPlane,newPlane); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
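| // Hedged example for RPHyperPlanes: planes are appended lazily, one per depth: |  | ||||||
| //   RPHyperPlanes planes = new RPHyperPlanes(32); |  | ||||||
| //   planes.addRandomHyperPlane();             // row 0, a scaled random vector |  | ||||||
| //   planes.addRandomHyperPlane();             // row 1 |  | ||||||
| //   INDArray h0 = planes.getHyperPlaneAt(0);  // plane used at tree depth 0 |  | ||||||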
| @ -1,48 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.randomprojection; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| import lombok.Data; |  | ||||||
| 
 |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.List; |  | ||||||
| import java.util.concurrent.Future; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| public class RPNode { |  | ||||||
|     private int depth; |  | ||||||
|     private RPNode left,right; |  | ||||||
|     private Future<RPNode> leftFuture,rightFuture; |  | ||||||
|     private List<Integer> indices; |  | ||||||
|     private double median; |  | ||||||
|     private RPTree tree; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public RPNode(RPTree tree,int depth) { |  | ||||||
|         this.depth = depth; |  | ||||||
|         this.tree = tree; |  | ||||||
|         indices = new ArrayList<>(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,130 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.randomprojection; |  | ||||||
| 
 |  | ||||||
| import lombok.Builder; |  | ||||||
| import lombok.Data; |  | ||||||
| import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; |  | ||||||
| import org.nd4j.linalg.api.memory.enums.*; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.common.primitives.Pair; |  | ||||||
| 
 |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.Arrays; |  | ||||||
| import java.util.List; |  | ||||||
| import java.util.concurrent.ExecutorService; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| public class RPTree { |  | ||||||
|     private RPNode root; |  | ||||||
|     private RPHyperPlanes rpHyperPlanes; |  | ||||||
|     private int dim; |  | ||||||
|     //also known as leaf size |  | ||||||
|     private int maxSize; |  | ||||||
|     private INDArray X; |  | ||||||
|     private String similarityFunction = "euclidean"; |  | ||||||
|     private WorkspaceConfiguration workspaceConfiguration; |  | ||||||
|     private ExecutorService searchExecutor; |  | ||||||
|     private int searchWorkers; |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * |  | ||||||
|      * @param dim the dimension of the vectors |  | ||||||
|      * @param maxSize the max size of the leaves |  | ||||||
|      * @param similarityFunction the similarity function to use |  | ||||||
|      * |  | ||||||
|      */ |  | ||||||
|     @Builder |  | ||||||
|     public RPTree(int dim, int maxSize,String similarityFunction) { |  | ||||||
|         this.dim = dim; |  | ||||||
|         this.maxSize = maxSize; |  | ||||||
|         rpHyperPlanes = new RPHyperPlanes(dim); |  | ||||||
|         root = new RPNode(this,0); |  | ||||||
|         this.similarityFunction = similarityFunction; |  | ||||||
|         workspaceConfiguration = WorkspaceConfiguration.builder().cyclesBeforeInitialization(1) |  | ||||||
|                 .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.FIRST_LOOP) |  | ||||||
|                 .policyMirroring(MirroringPolicy.FULL).policyReset(ResetPolicy.BLOCK_LEFT) |  | ||||||
|                 .policySpill(SpillPolicy.REALLOCATE).build(); |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * |  | ||||||
|      * @param dim the dimension of the vectors |  | ||||||
|      * @param maxSize the max size of the leaves |  | ||||||
|      * |  | ||||||
|      */ |  | ||||||
|     public RPTree(int dim, int maxSize) { |  | ||||||
|        this(dim,maxSize,"euclidean"); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      *  Build the tree with the given input data |  | ||||||
|      * @param x the input data matrix (one row per vector) |  | ||||||
|      */ |  | ||||||
|     public void buildTree(INDArray x) { |  | ||||||
|         this.X = x; |  | ||||||
|         for(int i = 0; i < x.rows(); i++) { |  | ||||||
|             root.getIndices().add(i); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         RPUtils.buildTree(this,root,rpHyperPlanes, |  | ||||||
|                 x,maxSize,0,similarityFunction); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public void addNodeAtIndex(int idx,INDArray toAdd) { |  | ||||||
|         RPNode query = RPUtils.query(root,rpHyperPlanes,toAdd,similarityFunction); |  | ||||||
|         query.getIndices().add(idx); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public List<RPNode> getLeaves() { |  | ||||||
|         List<RPNode> nodes = new ArrayList<>(); |  | ||||||
|         RPUtils.scanForLeaves(nodes,getRoot()); |  | ||||||
|         return nodes; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Query all with the distances |  | ||||||
|      * sorted by index |  | ||||||
|      * @param query the query vector |  | ||||||
|      * @param numResults the number of results to return |  | ||||||
|      * @return a list of samples |  | ||||||
|      */ |  | ||||||
|     public List<Pair<Double, Integer>> queryWithDistances(INDArray query, int numResults) { |  | ||||||
|             return RPUtils.queryAllWithDistances(query,X,Arrays.asList(this),numResults,similarityFunction); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public INDArray query(INDArray query,int numResults) { |  | ||||||
|         return RPUtils.queryAll(query,X,Arrays.asList(this),numResults,similarityFunction); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public List<Integer> getCandidates(INDArray target) { |  | ||||||
|         return RPUtils.getCandidates(target,Arrays.asList(this),similarityFunction); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
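| // Minimal usage sketch for RPTree (data shape illustrative; the two-argument |  | ||||||
| // constructor defaults to euclidean distance as documented above): |  | ||||||
| //   INDArray x = Nd4j.randn(500, 16); |  | ||||||
| //   RPTree tree = new RPTree(16, 50); |  | ||||||
| //   tree.buildTree(x); |  | ||||||
| //   List<Pair<Double, Integer>> nn = tree.queryWithDistances(x.getRow(0), 5); |  | ||||||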
| @ -1,481 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.randomprojection; |  | ||||||
| 
 |  | ||||||
| import org.nd4j.shade.guava.primitives.Doubles; |  | ||||||
| import lombok.val; |  | ||||||
| import org.nd4j.autodiff.functions.DifferentialFunction; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.api.ops.ReduceOp; |  | ||||||
| import org.nd4j.linalg.api.ops.impl.reduce3.*; |  | ||||||
| import org.nd4j.linalg.exception.ND4JIllegalArgumentException; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| import org.nd4j.common.primitives.Pair; |  | ||||||
| 
 |  | ||||||
| import java.util.*; |  | ||||||
| 
 |  | ||||||
| public class RPUtils { |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     private static ThreadLocal<Map<String,DifferentialFunction>> functionInstances = new ThreadLocal<>(); |  | ||||||
| 
 |  | ||||||
|     public static <T extends DifferentialFunction> DifferentialFunction getOp(String name, |  | ||||||
|                                                                               INDArray x, |  | ||||||
|                                                                               INDArray y, |  | ||||||
|                                                                               INDArray result) { |  | ||||||
|         Map<String,DifferentialFunction> ops = functionInstances.get(); |  | ||||||
|         if(ops == null) { |  | ||||||
|             ops = new HashMap<>(); |  | ||||||
|             functionInstances.set(ops); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         boolean allDistances = x.length() != y.length(); |  | ||||||
| 
 |  | ||||||
|         switch(name) { |  | ||||||
|             case "cosinedistance": |  | ||||||
|                 if(!ops.containsKey(name) || ((CosineDistance)ops.get(name)).isComplexAccumulation() != allDistances) { |  | ||||||
|                     CosineDistance cosineDistance = new CosineDistance(x,y,result,allDistances); |  | ||||||
|                     ops.put(name,cosineDistance); |  | ||||||
|                     return cosineDistance; |  | ||||||
|                 } |  | ||||||
|                 else { |  | ||||||
|                     CosineDistance cosineDistance = (CosineDistance) ops.get(name); |  | ||||||
|                     // update the cached op's operands, as the other branches do |  | ||||||
|                     cosineDistance.setX(x); |  | ||||||
|                     cosineDistance.setY(y); |  | ||||||
|                     cosineDistance.setZ(result); |  | ||||||
|                     return cosineDistance; |  | ||||||
|                 } |  | ||||||
|             case "cosinesimilarity": |  | ||||||
|                 if(!ops.containsKey(name) || ((CosineSimilarity)ops.get(name)).isComplexAccumulation() != allDistances) { |  | ||||||
|                     CosineSimilarity cosineSimilarity = new CosineSimilarity(x,y,result,allDistances); |  | ||||||
|                     ops.put(name,cosineSimilarity); |  | ||||||
|                     return cosineSimilarity; |  | ||||||
|                 } |  | ||||||
|                 else { |  | ||||||
|                     CosineSimilarity cosineSimilarity = (CosineSimilarity) ops.get(name); |  | ||||||
|                     cosineSimilarity.setX(x); |  | ||||||
|                     cosineSimilarity.setY(y); |  | ||||||
|                     cosineSimilarity.setZ(result); |  | ||||||
|                     return cosineSimilarity; |  | ||||||
| 
 |  | ||||||
|                 } |  | ||||||
|             case "manhattan": |  | ||||||
|                 if(!ops.containsKey(name) || ((ManhattanDistance)ops.get(name)).isComplexAccumulation() != allDistances) { |  | ||||||
|                     ManhattanDistance manhattanDistance = new ManhattanDistance(x,y,result,allDistances); |  | ||||||
|                     ops.put(name,manhattanDistance); |  | ||||||
|                     return manhattanDistance; |  | ||||||
|                 } |  | ||||||
|                 else { |  | ||||||
|                     ManhattanDistance manhattanDistance = (ManhattanDistance) ops.get(name); |  | ||||||
|                     manhattanDistance.setX(x); |  | ||||||
|                     manhattanDistance.setY(y); |  | ||||||
|                     manhattanDistance.setZ(result); |  | ||||||
|                     return  manhattanDistance; |  | ||||||
|                 } |  | ||||||
|             case "jaccard": |  | ||||||
|                 if(!ops.containsKey(name) || ((JaccardDistance)ops.get(name)).isComplexAccumulation() != allDistances) { |  | ||||||
|                     JaccardDistance jaccardDistance = new JaccardDistance(x,y,result,allDistances); |  | ||||||
|                     ops.put(name,jaccardDistance); |  | ||||||
|                     return jaccardDistance; |  | ||||||
|                 } |  | ||||||
|                 else { |  | ||||||
|                     JaccardDistance jaccardDistance = (JaccardDistance) ops.get(name); |  | ||||||
|                     jaccardDistance.setX(x); |  | ||||||
|                     jaccardDistance.setY(y); |  | ||||||
|                     jaccardDistance.setZ(result); |  | ||||||
|                     return jaccardDistance; |  | ||||||
|                 } |  | ||||||
|             case "hamming": |  | ||||||
|                 if(!ops.containsKey(name) || ((HammingDistance)ops.get(name)).isComplexAccumulation() != allDistances) { |  | ||||||
|                     HammingDistance hammingDistance = new HammingDistance(x,y,result,allDistances); |  | ||||||
|                     ops.put(name,hammingDistance); |  | ||||||
|                     return hammingDistance; |  | ||||||
|                 } |  | ||||||
|                 else { |  | ||||||
|                     HammingDistance hammingDistance = (HammingDistance) ops.get(name); |  | ||||||
|                     hammingDistance.setX(x); |  | ||||||
|                     hammingDistance.setY(y); |  | ||||||
|                     hammingDistance.setZ(result); |  | ||||||
|                     return hammingDistance; |  | ||||||
|                 } |  | ||||||
|                 //euclidean |  | ||||||
|             default: |  | ||||||
|                 if(!ops.containsKey(name) || ((EuclideanDistance)ops.get(name)).isComplexAccumulation() != allDistances) { |  | ||||||
|                     EuclideanDistance euclideanDistance = new EuclideanDistance(x,y,result,allDistances); |  | ||||||
|                     ops.put(name,euclideanDistance); |  | ||||||
|                     return euclideanDistance; |  | ||||||
|                 } |  | ||||||
|                 else { |  | ||||||
|                     EuclideanDistance euclideanDistance = (EuclideanDistance) ops.get(name); |  | ||||||
|                     euclideanDistance.setX(x); |  | ||||||
|                     euclideanDistance.setY(y); |  | ||||||
|                     euclideanDistance.setZ(result); |  | ||||||
|                     return euclideanDistance; |  | ||||||
|                 } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Query all trees using the given input and data |  | ||||||
|      * @param toQuery the query vector |  | ||||||
|      * @param X the input data to query |  | ||||||
|      * @param trees the trees to query |  | ||||||
|      * @param n the number of results to search for |  | ||||||
|      * @param similarityFunction the similarity function to use |  | ||||||
|      * @return the distances and indices of the nearest neighbors, sorted by distance |  | ||||||
|      */ |  | ||||||
|     public static List<Pair<Double,Integer>> queryAllWithDistances(INDArray toQuery,INDArray X,List<RPTree> trees,int n,String similarityFunction) { |  | ||||||
|         if(trees.isEmpty()) { |  | ||||||
|             throw new ND4JIllegalArgumentException("Trees is empty!"); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         List<Integer> candidates = getCandidates(toQuery, trees,similarityFunction); |  | ||||||
|         val sortedCandidates = sortCandidates(toQuery,X,candidates,similarityFunction); |  | ||||||
|         int numReturns = Math.min(n,sortedCandidates.size()); |  | ||||||
|         List<Pair<Double,Integer>> ret = new ArrayList<>(numReturns); |  | ||||||
|         for(int i = 0; i < numReturns; i++) { |  | ||||||
|             ret.add(sortedCandidates.get(i)); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Query all trees using the given input and data |  | ||||||
|      * @param toQuery the query vector |  | ||||||
|      * @param X the input data to query |  | ||||||
|      * @param trees the trees to query |  | ||||||
|      * @param n the number of results to search for |  | ||||||
|      * @param similarityFunction the similarity function to use |  | ||||||
|      * @return the indices (in order) in the ndarray |  | ||||||
|      */ |  | ||||||
|     public static INDArray queryAll(INDArray toQuery,INDArray X,List<RPTree> trees,int n,String similarityFunction) { |  | ||||||
|         if(trees.isEmpty()) { |  | ||||||
|             throw new ND4JIllegalArgumentException("Trees is empty!"); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         List<Integer> candidates = getCandidates(toQuery, trees,similarityFunction); |  | ||||||
|         val sortedCandidates = sortCandidates(toQuery,X,candidates,similarityFunction); |  | ||||||
|         int numReturns = Math.min(n,sortedCandidates.size()); |  | ||||||
| 
 |  | ||||||
|         INDArray result = Nd4j.create(numReturns); |  | ||||||
|         for(int i = 0; i < numReturns; i++) { |  | ||||||
|             result.putScalar(i,sortedCandidates.get(i).getSecond()); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         return result; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get the sorted distances given the |  | ||||||
|      * query vector, input data, given the list of possible search candidates |  | ||||||
|      * @param x the query vector |  | ||||||
|      * @param X the input data to use |  | ||||||
|      * @param candidates the possible search candidates |  | ||||||
|      * @param similarityFunction the similarity function to use |  | ||||||
|      * @return the sorted distances |  | ||||||
|      */ |  | ||||||
|     public static List<Pair<Double,Integer>> sortCandidates(INDArray x,INDArray X, |  | ||||||
|                                                             List<Integer> candidates, |  | ||||||
|                                                             String similarityFunction) { |  | ||||||
|         int prevIdx = -1; |  | ||||||
|         List<Pair<Double,Integer>> ret = new ArrayList<>(); |  | ||||||
|         for(int i = 0; i < candidates.size(); i++) { |  | ||||||
|             if(candidates.get(i) != prevIdx) { |  | ||||||
|                 ret.add(Pair.of(computeDistance(similarityFunction,X.slice(candidates.get(i)),x),candidates.get(i))); |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             prevIdx = candidates.get(i); // track the previous candidate value, not the loop index |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         Collections.sort(ret, new Comparator<Pair<Double, Integer>>() { |  | ||||||
|             @Override |  | ||||||
|             public int compare(Pair<Double, Integer> doubleIntegerPair, Pair<Double, Integer> t1) { |  | ||||||
|                 return Doubles.compare(doubleIntegerPair.getFirst(),t1.getFirst()); |  | ||||||
|             } |  | ||||||
|         }); |  | ||||||
| 
 |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get the search candidates as indices given the input |  | ||||||
|      * and similarity function |  | ||||||
|      * @param x the input data to search with |  | ||||||
|      * @param trees the trees to search |  | ||||||
|      * @param similarityFunction the function to use for similarity |  | ||||||
|      * @return the list of indices as the search results |  | ||||||
|      */ |  | ||||||
|     public static INDArray getAllCandidates(INDArray x,List<RPTree> trees,String similarityFunction) { |  | ||||||
|         List<Integer> candidates = getCandidates(x,trees,similarityFunction); |  | ||||||
|         Collections.sort(candidates); |  | ||||||
| 
 |  | ||||||
|         int prevIdx = -1; |  | ||||||
|         int idxCount = 1; // the current run already contains one occurrence |  | ||||||
|         List<Pair<Integer,Integer>> scores = new ArrayList<>(); |  | ||||||
|         for(int i = 0; i < candidates.size(); i++) { |  | ||||||
|             if(candidates.get(i) == prevIdx) { |  | ||||||
|                 idxCount++; |  | ||||||
|             } |  | ||||||
|             else if(prevIdx != -1) { |  | ||||||
|                 scores.add(Pair.of(idxCount,prevIdx)); |  | ||||||
|                 idxCount = 1; |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             prevIdx = candidates.get(i); // track the previous candidate value, not the loop index |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         scores.add(Pair.of(idxCount,prevIdx)); |  | ||||||
| 
 |  | ||||||
|         INDArray arr = Nd4j.create(scores.size()); |  | ||||||
|         for(int i = 0; i < scores.size(); i++) { |  | ||||||
|             arr.putScalar(i,scores.get(i).getSecond()); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return arr; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Get the search candidates as indices given the input |  | ||||||
|      * and similarity function |  | ||||||
|      * @param x the input data to search with |  | ||||||
|      * @param roots the trees to search |  | ||||||
|      * @param similarityFunction the function to use for similarity |  | ||||||
|      * @return the list of indices as the search results |  | ||||||
|      */ |  | ||||||
|     public static List<Integer> getCandidates(INDArray x,List<RPTree> roots,String similarityFunction) { |  | ||||||
|         Set<Integer> ret = new LinkedHashSet<>(); |  | ||||||
|         for(RPTree tree : roots) { |  | ||||||
|             RPNode root = tree.getRoot(); |  | ||||||
|             RPNode query = query(root,tree.getRpHyperPlanes(),x,similarityFunction); |  | ||||||
|             ret.addAll(query.getIndices()); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return new ArrayList<>(ret); |  | ||||||
|     } |  | ||||||
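|     // The LinkedHashSet above merges the leaf buckets returned by the trees, |  | ||||||
|     // removing duplicate indices while preserving first-seen order. |  | ||||||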
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Query the tree starting from the given node |  | ||||||
|      * using the given hyper plane and similarity function |  | ||||||
|      * @param from the node to start from |  | ||||||
|      * @param planes the hyperplanes to route the query with |  | ||||||
|      * @param x the input data |  | ||||||
|      * @param similarityFunction the similarity function to use |  | ||||||
|      * @return the leaf node representing the given query from a |  | ||||||
|      * search in the tree |  | ||||||
|      */ |  | ||||||
|     public static  RPNode query(RPNode from,RPHyperPlanes planes,INDArray x,String similarityFunction) { |  | ||||||
|         if(from.getLeft() == null && from.getRight() == null) { |  | ||||||
|             return from; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         INDArray hyperPlane = planes.getHyperPlaneAt(from.getDepth()); |  | ||||||
|         double dist = computeDistance(similarityFunction,x,hyperPlane); |  | ||||||
|         if(dist <= from.getMedian()) { |  | ||||||
|             return query(from.getLeft(),planes,x,similarityFunction); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         else { |  | ||||||
|             return query(from.getRight(),planes,x,similarityFunction); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Compute the row-wise distances between two inputs |  | ||||||
|      * given a function name. Valid function names: |  | ||||||
|      * euclidean: euclidean distance |  | ||||||
|      * cosinedistance: cosine distance |  | ||||||
|      * cosinesimilarity: cosine similarity |  | ||||||
|      * manhattan: manhattan distance |  | ||||||
|      * jaccard: jaccard distance |  | ||||||
|      * hamming: hamming distance |  | ||||||
|      * @param function the function to use (defaults to euclidean distance) |  | ||||||
|      * @param x the first vector or matrix of row vectors |  | ||||||
|      * @param y the second vector or matrix of row vectors |  | ||||||
|      * @param result the array the result is placed in |  | ||||||
|      * @return the distances, reduced along dimension 1 |  | ||||||
|      */ |  | ||||||
|     public static INDArray computeDistanceMulti(String function,INDArray x,INDArray y,INDArray result) { |  | ||||||
|         ReduceOp op = (ReduceOp) getOp(function, x, y, result); |  | ||||||
|         op.setDimensions(1); |  | ||||||
|         Nd4j.getExecutioner().exec(op); |  | ||||||
|         return op.z(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Compute the distance between 2 vectors |  | ||||||
|      * given a function name. Valid function names: |  | ||||||
|      * euclidean: euclidean distance |  | ||||||
|      * cosinedistance: cosine distance |  | ||||||
|      * cosinesimilarity: cosine similarity |  | ||||||
|      * manhattan: manhattan distance |  | ||||||
|      * jaccard: jaccard distance |  | ||||||
|      * hamming: hamming distance |  | ||||||
|      * @param function the function to use (default euclidean distance) |  | ||||||
|      * @param x the first vector |  | ||||||
|      * @param y the second vector |  | ||||||
|      * @return the distance between the 2 vectors given the inputs |  | ||||||
|      */ |  | ||||||
|     public static double computeDistance(String function,INDArray x,INDArray y,INDArray result) { |  | ||||||
|         ReduceOp op = (ReduceOp) getOp(function, x, y, result); |  | ||||||
|         Nd4j.getExecutioner().exec(op); |  | ||||||
|         return op.z().getDouble(0); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Compute the distance between 2 vectors |  | ||||||
|      * given a function name. Valid function names: |  | ||||||
|      * euclidean: euclidean distance |  | ||||||
|      * cosinedistance: cosine distance |  | ||||||
|      * cosinesimilarity: cosine similarity |  | ||||||
|      * manhattan: manhattan distance |  | ||||||
|      * jaccard: jaccard distance |  | ||||||
|      * hamming: hamming distance |  | ||||||
|      * @param function the function to use (default euclidean distance) |  | ||||||
|      * @param x the first vector |  | ||||||
|      * @param y the second vector |  | ||||||
|      * @return the distance between the 2 vectors given the inputs |  | ||||||
|      */ |  | ||||||
|     public static double computeDistance(String function,INDArray x,INDArray y) { |  | ||||||
|         return computeDistance(function,x,y,Nd4j.scalar(0.0)); |  | ||||||
|     } |  | ||||||
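|     // Worked example (values follow from the euclidean formula): |  | ||||||
|     //   double d = RPUtils.computeDistance("euclidean", |  | ||||||
|     //           Nd4j.create(new double[] {0, 0}), |  | ||||||
|     //           Nd4j.create(new double[] {3, 4}));  // d == 5.0 |  | ||||||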
| 
 |  | ||||||
|     /** |  | ||||||
|      * Initialize the tree given the input parameters |  | ||||||
|      * @param tree the tree to initialize |  | ||||||
|      * @param from the starting node |  | ||||||
|      * @param planes the hyper planes to use (vector space for similarity) |  | ||||||
|      * @param X the input data |  | ||||||
|      * @param maxSize the max number of indices on a given leaf node |  | ||||||
|      * @param depth the current depth of the tree |  | ||||||
|      * @param similarityFunction the similarity function to use |  | ||||||
|      */ |  | ||||||
|     public static void buildTree(RPTree tree, |  | ||||||
|                                  RPNode from, |  | ||||||
|                                  RPHyperPlanes planes, |  | ||||||
|                                  INDArray X, |  | ||||||
|                                  int maxSize, |  | ||||||
|                                  int depth, |  | ||||||
|                                  String similarityFunction) { |  | ||||||
|         if(from.getIndices().size() <= maxSize) { |  | ||||||
|             // small enough to be a leaf; slimNode only clears indices on nodes with children |  | ||||||
|             slimNode(from); |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         List<Double> distances = new ArrayList<>(); |  | ||||||
|         RPNode left = new RPNode(tree,depth + 1); |  | ||||||
|         RPNode right = new RPNode(tree,depth + 1); |  | ||||||
| 
 |  | ||||||
|         if(planes.getWholeHyperPlane() == null || depth >= planes.getWholeHyperPlane().rows()) { |  | ||||||
|             planes.addRandomHyperPlane(); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         INDArray hyperPlane = planes.getHyperPlaneAt(depth); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         for(int i = 0; i < from.getIndices().size(); i++) { |  | ||||||
|             double cosineSim = computeDistance(similarityFunction,hyperPlane,X.slice(from.getIndices().get(i))); |  | ||||||
|             distances.add(cosineSim); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         Collections.sort(distances); |  | ||||||
|         from.setMedian(distances.get(distances.size() / 2)); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         for(int i = 0; i < from.getIndices().size(); i++) { |  | ||||||
|             double cosineSim = computeDistance(similarityFunction,hyperPlane,X.slice(from.getIndices().get(i))); |  | ||||||
|             if(cosineSim <= from.getMedian()) { |  | ||||||
|                 left.getIndices().add(from.getIndices().get(i)); |  | ||||||
|             } |  | ||||||
|             else { |  | ||||||
|                 right.getIndices().add(from.getIndices().get(i)); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         //failed split |  | ||||||
|         if(left.getIndices().isEmpty() || right.getIndices().isEmpty()) { |  | ||||||
|             slimNode(from); |  | ||||||
|             return; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         from.setLeft(left); |  | ||||||
|         from.setRight(right); |  | ||||||
|         slimNode(from); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         buildTree(tree,left,planes,X,maxSize,depth + 1,similarityFunction); |  | ||||||
|         buildTree(tree,right,planes,X,maxSize,depth + 1,similarityFunction); |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
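|     // Each depth level shares one random hyperplane; splitting on the median |  | ||||||
|     // distance to it sends about half of the indices to each child, so the |  | ||||||
|     // expected depth is roughly log2(n / maxSize). |  | ||||||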
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Scan for leaves accumulating |  | ||||||
|      * the nodes in the passed in list |  | ||||||
|      * @param nodes the nodes so far |  | ||||||
|      * @param scan the tree to scan |  | ||||||
|      */ |  | ||||||
|     public static void scanForLeaves(List<RPNode> nodes,RPTree scan) { |  | ||||||
|         scanForLeaves(nodes,scan.getRoot()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Scan for leaves accumulating |  | ||||||
|      * the nodes in the passed in list |  | ||||||
|      * @param nodes the nodes so far |  | ||||||
|      * @param current the node to start scanning from |  | ||||||
|      */ |  | ||||||
|     public static void scanForLeaves(List<RPNode> nodes,RPNode current) { |  | ||||||
|         if(current.getLeft() == null && current.getRight() == null) |  | ||||||
|             nodes.add(current); |  | ||||||
|         if(current.getLeft() != null) |  | ||||||
|             scanForLeaves(nodes,current.getLeft()); |  | ||||||
|         if(current.getRight() != null) |  | ||||||
|             scanForLeaves(nodes,current.getRight()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Prune the indices from a node that has |  | ||||||
|      * been split (only leaf nodes keep their indices) |  | ||||||
|      * @param node the node to prune |  | ||||||
|      */ |  | ||||||
|     public static void slimNode(RPNode node) { |  | ||||||
|         if(node.getRight() != null && node.getLeft() != null) { |  | ||||||
|             node.getIndices().clear(); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,87 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.sptree; |  | ||||||
| 
 |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * @author Adam Gibson |  | ||||||
|  */ |  | ||||||
| public class Cell implements Serializable { |  | ||||||
|     private int dimension; |  | ||||||
|     private INDArray corner, width; |  | ||||||
| 
 |  | ||||||
|     public Cell(int dimension) { |  | ||||||
|         this.dimension = dimension; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public double corner(int d) { |  | ||||||
|         return corner.getDouble(d); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public double width(int d) { |  | ||||||
|         return width.getDouble(d); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setCorner(int d, double corner) { |  | ||||||
|         this.corner.putScalar(d, corner); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setWidth(int d, double width) { |  | ||||||
|         this.width.putScalar(d, width); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setWidth(INDArray width) { |  | ||||||
|         this.width = width; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setCorner(INDArray corner) { |  | ||||||
|         this.corner = corner; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public boolean contains(INDArray point) { |  | ||||||
|         INDArray cornerMinusWidth = corner.sub(width); |  | ||||||
|         INDArray cornerPlusWidth = corner.add(width); |  | ||||||
|         for (int d = 0; d < dimension; d++) { |  | ||||||
|             double pointD = point.getDouble(d); |  | ||||||
|             if (cornerMinusWidth.getDouble(d) > pointD) |  | ||||||
|                 return false; |  | ||||||
|             if (cornerPlusWidth.getDouble(d) < pointD) |  | ||||||
|                 return false; |  | ||||||
|         } |  | ||||||
|         return true; |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public INDArray width() { |  | ||||||
|         return width; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public INDArray corner() { |  | ||||||
|         return corner; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
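| // Worked example for contains(): a cell with corner (0, 0) and width (1, 1) |  | ||||||
| // spans [-1, 1] x [-1, 1], since the bounds are corner +/- width: |  | ||||||
| //   Cell cell = new Cell(2); |  | ||||||
| //   cell.setCorner(Nd4j.create(new double[] {0, 0})); |  | ||||||
| //   cell.setWidth(Nd4j.create(new double[] {1, 1})); |  | ||||||
| //   cell.contains(Nd4j.create(new double[] {0.5, -0.5}));  // true |  | ||||||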
| @ -1,95 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.sptree; |  | ||||||
| 
 |  | ||||||
| import lombok.Data; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.api.ops.impl.reduce3.CosineSimilarity; |  | ||||||
| import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance; |  | ||||||
| import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| public class DataPoint implements Serializable { |  | ||||||
|     private int index; |  | ||||||
|     private INDArray point; |  | ||||||
|     private long d; |  | ||||||
|     private String functionName; |  | ||||||
|     private boolean invert = false; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public DataPoint(int index, INDArray point, boolean invert) { |  | ||||||
|         this(index, point, "euclidean"); |  | ||||||
|         this.invert = invert; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public DataPoint(int index, INDArray point, String functionName, boolean invert) { |  | ||||||
|         this.index = index; |  | ||||||
|         this.point = point; |  | ||||||
|         this.functionName = functionName; |  | ||||||
|         this.d = point.length(); |  | ||||||
|         this.invert = invert; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public DataPoint(int index, INDArray point) { |  | ||||||
|         this(index, point, false); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public DataPoint(int index, INDArray point, String functionName) { |  | ||||||
|         this(index, point, functionName, false); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Distance from this point to the given point, using the configured |  | ||||||
|      * distance function (euclidean by default); negated when invert is set. |  | ||||||
|      * @param point the point to compute the distance to |  | ||||||
|      * @return the distance between the two points |  | ||||||
|      */ |  | ||||||
|     public float distance(DataPoint point) { |  | ||||||
|         switch (functionName) { |  | ||||||
|             case "euclidean": |  | ||||||
|                 float ret = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this.point, point.point)) |  | ||||||
|                                 .getFinalResult().floatValue(); |  | ||||||
|                 return invert ? -ret : ret; |  | ||||||
| 
 |  | ||||||
|             case "cosinesimilarity": |  | ||||||
|                 float ret2 = Nd4j.getExecutioner().execAndReturn(new CosineSimilarity(this.point, point.point)) |  | ||||||
|                                 .getFinalResult().floatValue(); |  | ||||||
|                 return invert ? -ret2 : ret2; |  | ||||||
| 
 |  | ||||||
|             case "manhattan": |  | ||||||
|                 float ret3 = Nd4j.getExecutioner().execAndReturn(new ManhattanDistance(this.point, point.point)) |  | ||||||
|                                 .getFinalResult().floatValue(); |  | ||||||
|                 return invert ? -ret3 : ret3; |  | ||||||
|             case "dot": |  | ||||||
|                 float dotRet = (float) Nd4j.getBlasWrapper().dot(this.point, point.point); |  | ||||||
|                 return invert ? -dotRet : dotRet; |  | ||||||
|             default: |  | ||||||
|                 float ret4 = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(this.point, point.point)) |  | ||||||
|                                 .getFinalResult().floatValue(); |  | ||||||
|                 return invert ? -ret4 : ret4; |  | ||||||
| 
 |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
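| // Hedged example for distance() (euclidean is the default function name): |  | ||||||
| //   DataPoint a = new DataPoint(0, Nd4j.create(new double[] {0, 0})); |  | ||||||
| //   DataPoint b = new DataPoint(1, Nd4j.create(new double[] {3, 4})); |  | ||||||
| //   a.distance(b);  // 5.0f |  | ||||||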
| @ -1,83 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.sptree; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * @author Adam Gibson |  | ||||||
|  */ |  | ||||||
| public class HeapItem implements Serializable, Comparable<HeapItem> { |  | ||||||
|     private int index; |  | ||||||
|     private double distance; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public HeapItem(int index, double distance) { |  | ||||||
|         this.index = index; |  | ||||||
|         this.distance = distance; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public int getIndex() { |  | ||||||
|         return index; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setIndex(int index) { |  | ||||||
|         this.index = index; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public double getDistance() { |  | ||||||
|         return distance; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setDistance(double distance) { |  | ||||||
|         this.distance = distance; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public boolean equals(Object o) { |  | ||||||
|         if (this == o) |  | ||||||
|             return true; |  | ||||||
|         if (o == null || getClass() != o.getClass()) |  | ||||||
|             return false; |  | ||||||
| 
 |  | ||||||
|         HeapItem heapItem = (HeapItem) o; |  | ||||||
| 
 |  | ||||||
|         if (index != heapItem.index) |  | ||||||
|             return false; |  | ||||||
|         return Double.compare(heapItem.distance, distance) == 0; |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public int hashCode() { |  | ||||||
|         int result; |  | ||||||
|         long temp; |  | ||||||
|         result = index; |  | ||||||
|         temp = Double.doubleToLongBits(distance); |  | ||||||
|         result = 31 * result + (int) (temp ^ (temp >>> 32)); |  | ||||||
|         return result; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public int compareTo(HeapItem o) { |  | ||||||
|         return Double.compare(o.distance, distance); // contract-valid comparator preserving the original descending order |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,72 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.sptree; |  | ||||||
| 
 |  | ||||||
| import lombok.Data; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| 
 |  | ||||||
| @Data |  | ||||||
| public class HeapObject implements Serializable, Comparable<HeapObject> { |  | ||||||
|     private int index; |  | ||||||
|     private INDArray point; |  | ||||||
|     private double distance; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public HeapObject(int index, INDArray point, double distance) { |  | ||||||
|         this.index = index; |  | ||||||
|         this.point = point; |  | ||||||
|         this.distance = distance; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public boolean equals(Object o) { |  | ||||||
|         if (this == o) |  | ||||||
|             return true; |  | ||||||
|         if (o == null || getClass() != o.getClass()) |  | ||||||
|             return false; |  | ||||||
| 
 |  | ||||||
|         HeapObject heapObject = (HeapObject) o; |  | ||||||
| 
 |  | ||||||
|         if (!point.equals(heapObject.point)) |  | ||||||
|             return false; |  | ||||||
| 
 |  | ||||||
|         return Double.compare(heapObject.distance, distance) == 0; |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public int hashCode() { |  | ||||||
|         int result; |  | ||||||
|         long temp; |  | ||||||
|         result = index; |  | ||||||
|         temp = Double.doubleToLongBits(distance); |  | ||||||
|         result = 31 * result + (int) (temp ^ (temp >>> 32)); |  | ||||||
|         return result; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public int compareTo(HeapObject o) { |  | ||||||
|         return Double.compare(o.distance, distance); // contract-valid comparator preserving the original descending order |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,425 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.sptree; |  | ||||||
| 
 |  | ||||||
| import org.nd4j.shade.guava.util.concurrent.AtomicDouble; |  | ||||||
| import lombok.val; |  | ||||||
| import org.deeplearning4j.clustering.algorithm.Distance; |  | ||||||
| import org.deeplearning4j.nn.conf.WorkspaceMode; |  | ||||||
| import org.nd4j.linalg.api.memory.MemoryWorkspace; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.api.ops.custom.BarnesEdgeForces; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| import org.nd4j.linalg.api.memory.abstracts.DummyWorkspace; |  | ||||||
| import org.slf4j.Logger; |  | ||||||
| import org.slf4j.LoggerFactory; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.Collection; |  | ||||||
| import java.util.Set; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * @author Adam Gibson |  | ||||||
|  */ |  | ||||||
| public class SpTree implements Serializable { |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public final static String workspaceExternal = "SPTREE_LOOP_EXTERNAL"; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     private int D; |  | ||||||
|     private INDArray data; |  | ||||||
|     public final static int NODE_RATIO = 8000; |  | ||||||
|     private int N; |  | ||||||
|     private int size; |  | ||||||
|     private int cumSize; |  | ||||||
|     private Cell boundary; |  | ||||||
|     private INDArray centerOfMass; |  | ||||||
|     private SpTree parent; |  | ||||||
|     private int[] index; |  | ||||||
|     private int nodeCapacity; |  | ||||||
|     private int numChildren = 2; |  | ||||||
|     private boolean isLeaf = true; |  | ||||||
|     private Collection<INDArray> indices; |  | ||||||
|     private SpTree[] children; |  | ||||||
|     private static Logger log = LoggerFactory.getLogger(SpTree.class); |  | ||||||
|     private String similarityFunction = Distance.EUCLIDEAN.toString(); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices, |  | ||||||
|                   String similarityFunction) { |  | ||||||
|         init(parent, data, corner, width, indices, similarityFunction); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public SpTree(INDArray data, Collection<INDArray> indices, String similarityFunction) { |  | ||||||
|         this.indices = indices; |  | ||||||
|         this.N = data.rows(); |  | ||||||
|         this.D = data.columns(); |  | ||||||
|         this.similarityFunction = similarityFunction; |  | ||||||
|         data = data.dup(); |  | ||||||
|         INDArray meanY = data.mean(0); |  | ||||||
|         INDArray minY = data.min(0); |  | ||||||
|         INDArray maxY = data.max(0); |  | ||||||
|         INDArray width = Nd4j.create(data.dataType(), meanY.shape()); |  | ||||||
|         for (int i = 0; i < width.length(); i++) { |  | ||||||
|             width.putScalar(i, Math.max(maxY.getDouble(i) - meanY.getDouble(i), |  | ||||||
|                     meanY.getDouble(i) - minY.getDouble(i)) + Nd4j.EPS_THRESHOLD); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         try(MemoryWorkspace ws = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) { |  | ||||||
|             init(null, data, meanY, width, indices, similarityFunction); |  | ||||||
|             fill(N); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public SpTree(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices) { |  | ||||||
|         this(parent, data, corner, width, indices, "euclidean"); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public SpTree(INDArray data, Collection<INDArray> indices) { |  | ||||||
|         this(data, indices, "euclidean"); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public SpTree(INDArray data) { |  | ||||||
|         this(data, new ArrayList<INDArray>()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public MemoryWorkspace workspace() { |  | ||||||
|         return null; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private void init(SpTree parent, INDArray data, INDArray corner, INDArray width, Collection<INDArray> indices, |  | ||||||
|                       String similarityFunction) { |  | ||||||
| 
 |  | ||||||
|         this.parent = parent; |  | ||||||
|         D = data.columns(); |  | ||||||
|         N = data.rows(); |  | ||||||
|         this.similarityFunction = similarityFunction; |  | ||||||
|         nodeCapacity = N % NODE_RATIO; |  | ||||||
|         index = new int[nodeCapacity]; |  | ||||||
|         for (int d = 1; d < this.D; d++) |  | ||||||
|             numChildren *= 2; |  | ||||||
|         this.indices = indices; |  | ||||||
|         isLeaf = true; |  | ||||||
|         size = 0; |  | ||||||
|         cumSize = 0; |  | ||||||
|         children = new SpTree[numChildren]; |  | ||||||
|         this.data = data; |  | ||||||
|         boundary = new Cell(D); |  | ||||||
|         boundary.setCorner(corner.dup()); |  | ||||||
|         boundary.setWidth(width.dup()); |  | ||||||
|         centerOfMass = Nd4j.create(data.dataType(), D); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     private boolean insert(int index) { |  | ||||||
|         /*MemoryWorkspace workspace = |  | ||||||
|                 workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace() |  | ||||||
|                         : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread( |  | ||||||
|                         workspaceConfigurationExternal, |  | ||||||
|                         workspaceExternal); |  | ||||||
|         try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ { |  | ||||||
| 
 |  | ||||||
|             INDArray point = data.slice(index); |  | ||||||
|             /*boolean contains = false; |  | ||||||
|             SpTreeCell op = new SpTreeCell(boundary.corner(), boundary.width(), point, N, contains); |  | ||||||
|             Nd4j.getExecutioner().exec(op); |  | ||||||
|             op.getOutputArgument(0).getScalar(0); |  | ||||||
|             if (!contains) return false;*/ |  | ||||||
|             if (!boundary.contains(point)) |  | ||||||
|                 return false; |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|             cumSize++; |  | ||||||
|             double mult1 = (double) (cumSize - 1) / (double) cumSize; |  | ||||||
|             double mult2 = 1.0 / (double) cumSize; |  | ||||||
|             centerOfMass.muli(mult1); |  | ||||||
|             centerOfMass.addi(point.mul(mult2)); |  | ||||||
|             // If there is space in this quad tree and it is a leaf, add the object here |  | ||||||
|             if (isLeaf() && size < nodeCapacity) { |  | ||||||
|                 this.index[size] = index; |  | ||||||
|                 indices.add(point); |  | ||||||
|                 size++; |  | ||||||
|                 return true; |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|             for (int i = 0; i < size; i++) { |  | ||||||
|                 INDArray compPoint = data.slice(this.index[i]); |  | ||||||
|                 if (compPoint.equals(point)) |  | ||||||
|                     return true; |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|             if (isLeaf()) |  | ||||||
|                 subDivide(); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|             // Find out where the point can be inserted |  | ||||||
|             for (int i = 0; i < numChildren; i++) { |  | ||||||
|                 if (children[i].insert(index)) |  | ||||||
|                     return true; |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             throw new IllegalStateException("Shouldn't reach this state"); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
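The centre-of-mass update in insert() is an incremental mean: scaling by (n-1)/n and adding point/n leaves centerOfMass equal to the average of all points inserted so far, without storing them. A scalar sketch of the same arithmetic (hypothetical demo class; plain doubles standing in for INDArrays):

    public class RunningMeanDemo {
        public static void main(String[] args) {
            double centerOfMass = 0.0;
            int cumSize = 0;
            for (double x : new double[] {1.0, 2.0, 6.0}) {
                cumSize++;
                double mult1 = (double) (cumSize - 1) / cumSize; // weight of the old mean
                double mult2 = 1.0 / cumSize;                    // weight of the new point
                centerOfMass = centerOfMass * mult1 + x * mult2;
            }
            System.out.println(centerOfMass); // 3.0, the mean of {1, 2, 6}
        }
    }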
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Subdivide the node into 2^d children, |  | ||||||
|      * one per orthant of the cell |  | ||||||
|      */ |  | ||||||
|     public void subDivide() { |  | ||||||
|         /*MemoryWorkspace workspace = |  | ||||||
|                 workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace() |  | ||||||
|                         : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread( |  | ||||||
|                         workspaceConfigurationExternal, |  | ||||||
|                         workspaceExternal); |  | ||||||
|         try (MemoryWorkspace ws = workspace.notifyScopeEntered()) */{ |  | ||||||
| 
 |  | ||||||
|             INDArray newCorner = Nd4j.create(data.dataType(), D); |  | ||||||
|             INDArray newWidth = Nd4j.create(data.dataType(), D); |  | ||||||
|             for (int i = 0; i < numChildren; i++) { |  | ||||||
|                 int div = 1; |  | ||||||
|                 for (int d = 0; d < D; d++) { |  | ||||||
|                     newWidth.putScalar(d, .5 * boundary.width(d)); |  | ||||||
|                     if ((i / div) % 2 == 1) |  | ||||||
|                         newCorner.putScalar(d, boundary.corner(d) - .5 * boundary.width(d)); |  | ||||||
|                     else |  | ||||||
|                         newCorner.putScalar(d, boundary.corner(d) + .5 * boundary.width(d)); |  | ||||||
|                     div *= 2; |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 children[i] = new SpTree(this, data, newCorner, newWidth, indices); |  | ||||||
| 
 |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             // Move existing points to correct children |  | ||||||
|             for (int i = 0; i < size; i++) { |  | ||||||
|                 boolean success = false; |  | ||||||
|                 for (int j = 0; j < this.numChildren; j++) |  | ||||||
|                     if (!success) |  | ||||||
|                         success = children[j].insert(index[i]); |  | ||||||
| 
 |  | ||||||
|                 index[i] = -1; |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             // Empty parent node |  | ||||||
|             size = 0; |  | ||||||
|             isLeaf = false; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
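subDivide() enumerates the 2^D orthants of the cell by treating the child index as a bit string: bit d of i, extracted with (i / div) % 2, decides whether the child sits in the lower or upper half along dimension d. A standalone sketch of that index-to-orthant mapping (hypothetical demo class):

    public class OrthantDemo {
        public static void main(String[] args) {
            int D = 2;                       // 2 dimensions -> 4 children (a quadtree)
            int numChildren = 1 << D;
            for (int i = 0; i < numChildren; i++) {
                StringBuilder sb = new StringBuilder("child " + i + ":");
                int div = 1;
                for (int d = 0; d < D; d++) {
                    // same test as subDivide(): bit d of i picks the half
                    sb.append((i / div) % 2 == 1 ? " corner-w/2" : " corner+w/2");
                    div *= 2;
                }
                System.out.println(sb);
            }
        }
    }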
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Compute non-edge (repulsive) forces using the Barnes-Hut approximation |  | ||||||
|      * @param pointIndex index of the query point |  | ||||||
|      * @param theta accuracy/speed trade-off (larger values summarize more aggressively) |  | ||||||
|      * @param negativeForce accumulator for the repulsive force on the point |  | ||||||
|      * @param sumQ running sum of the Q normalization terms |  | ||||||
|      */ |  | ||||||
|     public void computeNonEdgeForces(int pointIndex, double theta, INDArray negativeForce, AtomicDouble sumQ) { |  | ||||||
|         // Make sure that we spend no time on empty nodes or self-interactions |  | ||||||
|         if (cumSize == 0 || (isLeaf() && size == 1 && index[0] == pointIndex)) |  | ||||||
|             return; |  | ||||||
|         // Difference buffer, allocated only once there is work to do |  | ||||||
|         INDArray buf = Nd4j.create(data.dataType(), this.D); |  | ||||||
|        /* MemoryWorkspace workspace = |  | ||||||
|                 workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace() |  | ||||||
|                         : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread( |  | ||||||
|                         workspaceConfigurationExternal, |  | ||||||
|                         workspaceExternal); |  | ||||||
|         try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ { |  | ||||||
| 
 |  | ||||||
|             // Compute distance between point and center-of-mass |  | ||||||
|             data.slice(pointIndex).subi(centerOfMass, buf); |  | ||||||
| 
 |  | ||||||
|             double D = Nd4j.getBlasWrapper().dot(buf, buf); |  | ||||||
|             // Check whether we can use this node as a "summary" |  | ||||||
|             double maxWidth = boundary.width().maxNumber().doubleValue(); |  | ||||||
|             if (isLeaf() || maxWidth / Math.sqrt(D) < theta) { |  | ||||||
| 
 |  | ||||||
|                 // Compute and add t-SNE force between point and current node |  | ||||||
|                 double Q = 1.0 / (1.0 + D); |  | ||||||
|                 double mult = cumSize * Q; |  | ||||||
|                 sumQ.addAndGet(mult); |  | ||||||
|                 mult *= Q; |  | ||||||
|                 negativeForce.addi(buf.mul(mult)); |  | ||||||
|             } else { |  | ||||||
| 
 |  | ||||||
|                 // Recursively apply Barnes-Hut to children |  | ||||||
|                 for (int i = 0; i < numChildren; i++) { |  | ||||||
|                     children[i].computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ); |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
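The summary test above is the core of Barnes-Hut: an entire subtree is replaced by its centre of mass whenever the cell is small relative to its distance from the query point, i.e. maxWidth / sqrt(D) < theta, where D is the squared distance here. A worked sketch with made-up numbers (hypothetical demo class):

    public class BarnesHutSummaryDemo {
        public static void main(String[] args) {
            double maxWidth = 1.0;  // widest side of the node's cell
            double Dsq = 9.0;       // squared distance from query point to centre of mass
            double theta = 0.5;     // accuracy/speed trade-off
            int cumSize = 10;       // number of points summarised by this node
            if (maxWidth / Math.sqrt(Dsq) < theta) { // 1/3 < 0.5 -> use the summary
                double Q = 1.0 / (1.0 + Dsq);        // 0.1, the t-SNE kernel
                double mult = cumSize * Q;           // 1.0 gets added to sumQ
                System.out.println("repulsion scale = " + mult * Q); // buf is scaled by 0.1
            }
        }
    }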
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Compute edge (attractive) forces using Barnes-Hut |  | ||||||
|      * @param rowP row offsets into colP/valP, CSR style (a vector) |  | ||||||
|      * @param colP column indices of the non-zero input similarities |  | ||||||
|      * @param valP values of the non-zero input similarities |  | ||||||
|      * @param N the number of elements |  | ||||||
|      * @param posF the positive (attractive) force, accumulated per row |  | ||||||
|      */ |  | ||||||
|     public void computeEdgeForces(INDArray rowP, INDArray colP, INDArray valP, int N, INDArray posF) { |  | ||||||
|         if (!rowP.isVector()) |  | ||||||
|             throw new IllegalArgumentException("RowP must be a vector"); |  | ||||||
| 
 |  | ||||||
|         // Loop over all edges in the graph |  | ||||||
|         // just execute native op |  | ||||||
|         Nd4j.exec(new BarnesEdgeForces(rowP, colP, valP, data, N, posF)); |  | ||||||
| 
 |  | ||||||
|         /* |  | ||||||
|         INDArray buf = Nd4j.create(data.dataType(), this.D); |  | ||||||
|         double D; |  | ||||||
|         for (int n = 0; n < N; n++) { |  | ||||||
|             INDArray slice = data.slice(n); |  | ||||||
|             for (int i = rowP.getInt(n); i < rowP.getInt(n + 1); i++) { |  | ||||||
| 
 |  | ||||||
|                 // Compute pairwise distance and Q-value |  | ||||||
|                 slice.subi(data.slice(colP.getInt(i)), buf); |  | ||||||
| 
 |  | ||||||
|                 D = 1.0 + Nd4j.getBlasWrapper().dot(buf, buf); |  | ||||||
|                 D = valP.getDouble(i) / D; |  | ||||||
| 
 |  | ||||||
|                 // Sum positive force |  | ||||||
|                 posF.slice(n).addi(buf.muli(D)); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|         */ |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public boolean isLeaf() { |  | ||||||
|         return isLeaf; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Verifies the structure of the tree (does bounds checking on each node) |  | ||||||
|      * @return true if the structure of the tree is correct |  | ||||||
|      */ |  | ||||||
|     public boolean isCorrect() { |  | ||||||
|         /*MemoryWorkspace workspace = |  | ||||||
|                 workspaceMode == WorkspaceMode.NONE ? new DummyWorkspace() |  | ||||||
|                         : Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread( |  | ||||||
|                         workspaceConfigurationExternal, |  | ||||||
|                         workspaceExternal); |  | ||||||
|         try (MemoryWorkspace ws = workspace.notifyScopeEntered())*/ { |  | ||||||
| 
 |  | ||||||
|             for (int n = 0; n < size; n++) { |  | ||||||
|                 INDArray point = data.slice(index[n]); |  | ||||||
|                 if (!boundary.contains(point)) |  | ||||||
|                     return false; |  | ||||||
|             } |  | ||||||
|             if (!isLeaf()) { |  | ||||||
|                 boolean correct = true; |  | ||||||
|                 for (int i = 0; i < numChildren; i++) |  | ||||||
|                     correct = correct && children[i].isCorrect(); |  | ||||||
|                 return correct; |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             return true; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * The depth of the node |  | ||||||
|      * @return the depth of the node |  | ||||||
|      */ |  | ||||||
|     public int depth() { |  | ||||||
|         if (isLeaf()) |  | ||||||
|             return 1; |  | ||||||
|         int depth = 1; |  | ||||||
|         int maxChildDepth = 0; |  | ||||||
|         for (int i = 0; i < numChildren; i++) { |  | ||||||
|             maxChildDepth = Math.max(maxChildDepth, children[i].depth()); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return depth + maxChildDepth; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private void fill(int n) { |  | ||||||
|         if (indices.isEmpty() && parent == null) |  | ||||||
|             for (int i = 0; i < n; i++) { |  | ||||||
|                 log.trace("Inserted " + i); |  | ||||||
|                 insert(i); |  | ||||||
|             } |  | ||||||
|         else |  | ||||||
|             log.warn("Called fill already"); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public SpTree[] getChildren() { |  | ||||||
|         return children; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public int getD() { |  | ||||||
|         return D; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public INDArray getCenterOfMass() { |  | ||||||
|         return centerOfMass; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public Cell getBoundary() { |  | ||||||
|         return boundary; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public int[] getIndex() { |  | ||||||
|         return index; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public int getCumSize() { |  | ||||||
|         return cumSize; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setCumSize(int cumSize) { |  | ||||||
|         this.cumSize = cumSize; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public int getNumChildren() { |  | ||||||
|         return numChildren; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void setNumChildren(int numChildren) { |  | ||||||
|         this.numChildren = numChildren; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
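As a quick end-to-end check of the class above, a usage sketch (hypothetical demo class; assumes nd4j on the classpath): the single-argument constructor dups the data, derives the bounding cell from the per-dimension mean and min/max, and inserts every row.

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class SpTreeDemo {
        public static void main(String[] args) {
            INDArray data = Nd4j.rand(500, 2); // 500 random points in 2-d
            SpTree tree = new SpTree(data);    // builds the tree and inserts all rows
            System.out.println(tree.isCorrect());      // true: every point lies inside its node's cell
            System.out.println(tree.getNumChildren()); // 4: a quadtree for 2-d data
        }
    }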
| @ -1,117 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.strategy; |  | ||||||
| 
 |  | ||||||
| import lombok.*; |  | ||||||
| import org.deeplearning4j.clustering.algorithm.Distance; |  | ||||||
| import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition; |  | ||||||
| import org.deeplearning4j.clustering.condition.ConvergenceCondition; |  | ||||||
| import org.deeplearning4j.clustering.condition.FixedIterationCountCondition; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| 
 |  | ||||||
| @AllArgsConstructor(access = AccessLevel.PROTECTED) |  | ||||||
| @NoArgsConstructor(access = AccessLevel.PROTECTED) |  | ||||||
| public abstract class BaseClusteringStrategy implements ClusteringStrategy, Serializable { |  | ||||||
|     @Getter(AccessLevel.PUBLIC) |  | ||||||
|     @Setter(AccessLevel.PROTECTED) |  | ||||||
|     protected ClusteringStrategyType type; |  | ||||||
|     @Getter(AccessLevel.PUBLIC) |  | ||||||
|     @Setter(AccessLevel.PROTECTED) |  | ||||||
|     protected Integer initialClusterCount; |  | ||||||
|     @Getter(AccessLevel.PUBLIC) |  | ||||||
|     @Setter(AccessLevel.PROTECTED) |  | ||||||
|     protected ClusteringAlgorithmCondition optimizationPhaseCondition; |  | ||||||
|     @Getter(AccessLevel.PUBLIC) |  | ||||||
|     @Setter(AccessLevel.PROTECTED) |  | ||||||
|     protected ClusteringAlgorithmCondition terminationCondition; |  | ||||||
|     @Getter(AccessLevel.PUBLIC) |  | ||||||
|     @Setter(AccessLevel.PROTECTED) |  | ||||||
|     protected boolean inverse; |  | ||||||
|     @Getter(AccessLevel.PUBLIC) |  | ||||||
|     @Setter(AccessLevel.PROTECTED) |  | ||||||
|     protected Distance distanceFunction; |  | ||||||
|     @Getter(AccessLevel.PUBLIC) |  | ||||||
|     @Setter(AccessLevel.PROTECTED) |  | ||||||
|     protected boolean allowEmptyClusters; |  | ||||||
| 
 |  | ||||||
|     public BaseClusteringStrategy(ClusteringStrategyType type, Integer initialClusterCount, Distance distanceFunction, |  | ||||||
|                     boolean allowEmptyClusters, boolean inverse) { |  | ||||||
|         this.type = type; |  | ||||||
|         this.initialClusterCount = initialClusterCount; |  | ||||||
|         this.distanceFunction = distanceFunction; |  | ||||||
|         this.allowEmptyClusters = allowEmptyClusters; |  | ||||||
|         this.inverse = inverse; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public BaseClusteringStrategy(ClusteringStrategyType clusteringStrategyType, int initialClusterCount, |  | ||||||
|                     Distance distanceFunction, boolean inverse) { |  | ||||||
|         this(clusteringStrategyType, initialClusterCount, distanceFunction, false, inverse); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Terminate clustering once the given iteration count is reached |  | ||||||
|      * @param maxIterationCount the maximum number of iterations to run |  | ||||||
|      * @return this strategy, for chaining |  | ||||||
|      */ |  | ||||||
|     public BaseClusteringStrategy endWhenIterationCountEquals(int maxIterationCount) { |  | ||||||
|         setTerminationCondition(FixedIterationCountCondition.iterationCountGreaterThan(maxIterationCount)); |  | ||||||
|         return this; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Terminate clustering once the point-distribution variation rate drops below the given rate |  | ||||||
|      * @param rate the convergence threshold |  | ||||||
|      * @return this strategy, for chaining |  | ||||||
|      */ |  | ||||||
|     public BaseClusteringStrategy endWhenDistributionVariationRateLessThan(double rate) { |  | ||||||
|         setTerminationCondition(ConvergenceCondition.distributionVariationRateLessThan(rate)); |  | ||||||
|         return this; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @return whether distances should be inverted (for similarity-style metrics) |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public boolean inverseDistanceCalculation() { |  | ||||||
|         return inverse; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Check whether this strategy is of the given type |  | ||||||
|      * @param type the type to compare against |  | ||||||
|      * @return true if this strategy's type equals the given type |  | ||||||
|      */ |  | ||||||
|     public boolean isStrategyOfType(ClusteringStrategyType type) { |  | ||||||
|         return type.equals(this.type); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @return the initial number of clusters |  | ||||||
|      */ |  | ||||||
|     public Integer getInitialClusterCount() { |  | ||||||
|         return initialClusterCount; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,102 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.strategy; |  | ||||||
| 
 |  | ||||||
| import org.deeplearning4j.clustering.algorithm.Distance; |  | ||||||
| import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition; |  | ||||||
| import org.deeplearning4j.clustering.iteration.IterationHistory; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * Strategy describing how a clustering run is configured, converges, and terminates |  | ||||||
|  */ |  | ||||||
| public interface ClusteringStrategy { |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @return whether distances should be inverted (for similarity-style metrics) |  | ||||||
|      */ |  | ||||||
|     boolean inverseDistanceCalculation(); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @return the type of this strategy |  | ||||||
|      */ |  | ||||||
|     ClusteringStrategyType getType(); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param type the type to compare against |  | ||||||
|      * @return true if this strategy is of the given type |  | ||||||
|      */ |  | ||||||
|     boolean isStrategyOfType(ClusteringStrategyType type); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @return the initial number of clusters |  | ||||||
|      */ |  | ||||||
|     Integer getInitialClusterCount(); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @return the distance function in use |  | ||||||
|      */ |  | ||||||
|     Distance getDistanceFunction(); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @return whether empty clusters are allowed |  | ||||||
|      */ |  | ||||||
|     boolean isAllowEmptyClusters(); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @return the condition that terminates clustering |  | ||||||
|      */ |  | ||||||
|     ClusteringAlgorithmCondition getTerminationCondition(); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @return whether an optimization phase has been configured |  | ||||||
|      */ |  | ||||||
|     boolean isOptimizationDefined(); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param iterationHistory the history of iterations so far |  | ||||||
|      * @return whether the optimization phase should run now |  | ||||||
|      */ |  | ||||||
|     boolean isOptimizationApplicableNow(IterationHistory iterationHistory); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param maxIterationCount the maximum number of iterations to run |  | ||||||
|      * @return this strategy, for chaining |  | ||||||
|      */ |  | ||||||
|     BaseClusteringStrategy endWhenIterationCountEquals(int maxIterationCount); |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param rate the convergence threshold |  | ||||||
|      * @return this strategy, for chaining |  | ||||||
|      */ |  | ||||||
|     BaseClusteringStrategy endWhenDistributionVariationRateLessThan(double rate); |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,25 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.strategy; |  | ||||||
| 
 |  | ||||||
| public enum ClusteringStrategyType { |  | ||||||
|     FIXED_CLUSTER_COUNT, OPTIMIZATION |  | ||||||
| } |  | ||||||
| @ -1,68 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.strategy; |  | ||||||
| 
 |  | ||||||
| import lombok.AccessLevel; |  | ||||||
| import lombok.NoArgsConstructor; |  | ||||||
| import org.deeplearning4j.clustering.algorithm.Distance; |  | ||||||
| import org.deeplearning4j.clustering.iteration.IterationHistory; |  | ||||||
| 
 |  | ||||||
| /** |  | ||||||
|  * Clustering strategy that keeps a fixed number of clusters |  | ||||||
|  */ |  | ||||||
| @NoArgsConstructor(access = AccessLevel.PROTECTED) |  | ||||||
| public class FixedClusterCountStrategy extends BaseClusteringStrategy { |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     protected FixedClusterCountStrategy(Integer initialClusterCount, Distance distanceFunction, |  | ||||||
|                     boolean allowEmptyClusters, boolean inverse) { |  | ||||||
|         super(ClusteringStrategyType.FIXED_CLUSTER_COUNT, initialClusterCount, distanceFunction, allowEmptyClusters, |  | ||||||
|                         inverse); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param clusterCount the (fixed) number of clusters |  | ||||||
|      * @param distanceFunction the distance function to use |  | ||||||
|      * @param inverse whether to invert the distance calculation |  | ||||||
|      * @return a configured FixedClusterCountStrategy |  | ||||||
|      */ |  | ||||||
|     public static FixedClusterCountStrategy setup(int clusterCount, Distance distanceFunction, boolean inverse) { |  | ||||||
|         return new FixedClusterCountStrategy(clusterCount, distanceFunction, false, inverse); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @return whether distances should be inverted |  | ||||||
|      */ |  | ||||||
|     @Override |  | ||||||
|     public boolean inverseDistanceCalculation() { |  | ||||||
|         return inverse; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public boolean isOptimizationDefined() { |  | ||||||
|         return false; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public boolean isOptimizationApplicableNow(IterationHistory iterationHistory) { |  | ||||||
|         return false; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
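Putting the pieces together, a usage sketch of the fluent API (the demo class is hypothetical; Distance.EUCLIDEAN is the constant referenced elsewhere in this commit, and the strategy classes are assumed importable):

    import org.deeplearning4j.clustering.algorithm.Distance;

    public class StrategyUsageDemo {
        public static void main(String[] args) {
            // 5 clusters, plain (non-inverted) Euclidean distance,
            // stop after at most 100 iterations
            ClusteringStrategy strategy = FixedClusterCountStrategy
                    .setup(5, Distance.EUCLIDEAN, false)
                    .endWhenIterationCountEquals(100);
            System.out.println(strategy.isStrategyOfType(
                    ClusteringStrategyType.FIXED_CLUSTER_COUNT)); // true
        }
    }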
| @ -1,82 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.strategy; |  | ||||||
| 
 |  | ||||||
| import org.deeplearning4j.clustering.algorithm.Distance; |  | ||||||
| import org.deeplearning4j.clustering.condition.ClusteringAlgorithmCondition; |  | ||||||
| import org.deeplearning4j.clustering.condition.ConvergenceCondition; |  | ||||||
| import org.deeplearning4j.clustering.condition.FixedIterationCountCondition; |  | ||||||
| import org.deeplearning4j.clustering.iteration.IterationHistory; |  | ||||||
| import org.deeplearning4j.clustering.optimisation.ClusteringOptimization; |  | ||||||
| import org.deeplearning4j.clustering.optimisation.ClusteringOptimizationType; |  | ||||||
| 
 |  | ||||||
| public class OptimisationStrategy extends BaseClusteringStrategy { |  | ||||||
|     public static int defaultIterationCount = 100; |  | ||||||
| 
 |  | ||||||
|     private ClusteringOptimization clusteringOptimisation; |  | ||||||
|     private ClusteringAlgorithmCondition clusteringOptimisationApplicationCondition; |  | ||||||
| 
 |  | ||||||
|     protected OptimisationStrategy() { |  | ||||||
|         super(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     protected OptimisationStrategy(int initialClusterCount, Distance distanceFunction) { |  | ||||||
|         super(ClusteringStrategyType.OPTIMIZATION, initialClusterCount, distanceFunction, false); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public static OptimisationStrategy setup(int initialClusterCount, Distance distanceFunction) { |  | ||||||
|         return new OptimisationStrategy(initialClusterCount, distanceFunction); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public OptimisationStrategy optimize(ClusteringOptimizationType type, double value) { |  | ||||||
|         clusteringOptimisation = new ClusteringOptimization(type, value); |  | ||||||
|         return this; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public OptimisationStrategy optimizeWhenIterationCountMultipleOf(int value) { |  | ||||||
|         // Note: despite the name, this delegates to a fixed-count threshold |  | ||||||
|         // condition (iterationCountGreaterThan), not a modulo check |  | ||||||
|         clusteringOptimisationApplicationCondition = FixedIterationCountCondition.iterationCountGreaterThan(value); |  | ||||||
|         return this; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public OptimisationStrategy optimizeWhenPointDistributionVariationRateLessThan(double rate) { |  | ||||||
|         clusteringOptimisationApplicationCondition = ConvergenceCondition.distributionVariationRateLessThan(rate); |  | ||||||
|         return this; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     public double getClusteringOptimizationValue() { |  | ||||||
|         return clusteringOptimisation.getValue(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public boolean isClusteringOptimizationType(ClusteringOptimizationType type) { |  | ||||||
|         return clusteringOptimisation != null && clusteringOptimisation.getType().equals(type); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public boolean isOptimizationDefined() { |  | ||||||
|         return clusteringOptimisation != null; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public boolean isOptimizationApplicableNow(IterationHistory iterationHistory) { |  | ||||||
|         return clusteringOptimisationApplicationCondition != null |  | ||||||
|                         && clusteringOptimisationApplicationCondition.isSatisfied(iterationHistory); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
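For comparison with the fixed-count variant, a usage sketch (hypothetical demo class, same import assumptions as above): the application condition and the optimization target are configured independently, so isOptimizationDefined() stays false until optimize(type, value) is also called.

    import org.deeplearning4j.clustering.algorithm.Distance;

    public class OptimisationStrategyDemo {
        public static void main(String[] args) {
            OptimisationStrategy strategy = OptimisationStrategy
                    .setup(10, Distance.EUCLIDEAN)
                    .optimizeWhenPointDistributionVariationRateLessThan(0.05);
            strategy.endWhenIterationCountEquals(OptimisationStrategy.defaultIterationCount);
            System.out.println(strategy.isOptimizationDefined()); // false: no optimize(type, value) yet
        }
    }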
										
											
File diff suppressed because it is too large
							| @ -1,74 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.util; |  | ||||||
| 
 |  | ||||||
| import org.slf4j.Logger; |  | ||||||
| import org.slf4j.LoggerFactory; |  | ||||||
| 
 |  | ||||||
| import java.util.List; |  | ||||||
| import java.util.concurrent.*; |  | ||||||
| 
 |  | ||||||
| public class MultiThreadUtils { |  | ||||||
| 
 |  | ||||||
|     private static Logger log = LoggerFactory.getLogger(MultiThreadUtils.class); |  | ||||||
| 
 |  | ||||||
|     private static ExecutorService instance; |  | ||||||
| 
 |  | ||||||
|     private MultiThreadUtils() {} |  | ||||||
| 
 |  | ||||||
|     public static synchronized ExecutorService newExecutorService() { |  | ||||||
|         int nThreads = Runtime.getRuntime().availableProcessors(); |  | ||||||
|         return new ThreadPoolExecutor(nThreads, nThreads, 60L, TimeUnit.SECONDS, new LinkedTransferQueue<Runnable>(), |  | ||||||
|                         new ThreadFactory() { |  | ||||||
|                             @Override |  | ||||||
|                             public Thread newThread(Runnable r) { |  | ||||||
|                                 Thread t = Executors.defaultThreadFactory().newThread(r); |  | ||||||
|                                 t.setDaemon(true); |  | ||||||
|                                 return t; |  | ||||||
|                             } |  | ||||||
|                         }); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public static void parallelTasks(final List<Runnable> tasks, ExecutorService executorService) { |  | ||||||
|         int tasksCount = tasks.size(); |  | ||||||
|         final CountDownLatch latch = new CountDownLatch(tasksCount); |  | ||||||
|         for (int i = 0; i < tasksCount; i++) { |  | ||||||
|             final int taskIdx = i; |  | ||||||
|             executorService.execute(new Runnable() { |  | ||||||
|                 public void run() { |  | ||||||
|                     try { |  | ||||||
|                         tasks.get(taskIdx).run(); |  | ||||||
|                     } catch (Throwable e) { |  | ||||||
|                         log.warn("Unchecked exception thrown by task", e); |  | ||||||
|                     } finally { |  | ||||||
|                         latch.countDown(); |  | ||||||
|                     } |  | ||||||
|                 } |  | ||||||
|             }); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         try { |  | ||||||
|             latch.await(); |  | ||||||
|         } catch (Exception e) { |  | ||||||
|             throw new RuntimeException(e); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
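parallelTasks() blocks until every task has finished or failed, so callers can treat it as a synchronous fork-join step. A usage sketch (hypothetical demo class):

    import java.util.ArrayList;
    import java.util.List;
    import java.util.concurrent.ExecutorService;

    public class ParallelTasksDemo {
        public static void main(String[] args) {
            ExecutorService pool = MultiThreadUtils.newExecutorService();
            List<Runnable> tasks = new ArrayList<>();
            for (int i = 0; i < 4; i++) {
                final int id = i;
                tasks.add(() -> System.out.println("task " + id + " on "
                        + Thread.currentThread().getName()));
            }
            MultiThreadUtils.parallelTasks(tasks, pool); // returns only after all 4 have run
            pool.shutdown();
        }
    }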
| @ -1,61 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.util; |  | ||||||
| 
 |  | ||||||
| import java.util.Collection; |  | ||||||
| import java.util.HashSet; |  | ||||||
| import java.util.Set; |  | ||||||
| 
 |  | ||||||
| public class SetUtils { |  | ||||||
|     private SetUtils() {} |  | ||||||
| 
 |  | ||||||
|     // Set specific operations |  | ||||||
| 
 |  | ||||||
|     public static <T> Set<T> intersection(Collection<T> parentCollection, Collection<T> retainCollection) { |  | ||||||
|         // retainAll keeps the common elements, so the second argument is retained, not removed |  | ||||||
|         Set<T> results = new HashSet<>(parentCollection); |  | ||||||
|         results.retainAll(retainCollection); |  | ||||||
|         return results; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public static <T> boolean intersectionP(Set<? extends T> s1, Set<? extends T> s2) { |  | ||||||
|         for (T elt : s1) { |  | ||||||
|             if (s2.contains(elt)) |  | ||||||
|                 return true; |  | ||||||
|         } |  | ||||||
|         return false; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public static <T> Set<T> union(Set<? extends T> s1, Set<? extends T> s2) { |  | ||||||
|         Set<T> s3 = new HashSet<>(s1); |  | ||||||
|         s3.addAll(s2); |  | ||||||
|         return s3; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** Returns s1 \ s2, i.e. the elements of s1 that are not in s2 */ |  | ||||||
|     public static <T> Set<T> difference(Collection<? extends T> s1, Collection<? extends T> s2) { |  | ||||||
|         Set<T> s3 = new HashSet<>(s1); |  | ||||||
|         s3.removeAll(s2); |  | ||||||
|         return s3; |  | ||||||
|     } |  | ||||||
| } |  | ||||||
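A quick tour of the four operations (hypothetical demo class; HashSet iteration order is not guaranteed, so the printed order may vary):

    import java.util.Arrays;
    import java.util.HashSet;
    import java.util.Set;

    public class SetUtilsDemo {
        public static void main(String[] args) {
            Set<Integer> a = new HashSet<>(Arrays.asList(1, 2, 3));
            Set<Integer> b = new HashSet<>(Arrays.asList(3, 4));
            System.out.println(SetUtils.union(a, b));         // e.g. [1, 2, 3, 4]
            System.out.println(SetUtils.intersection(a, b));  // [3]
            System.out.println(SetUtils.difference(a, b));    // e.g. [1, 2]
            System.out.println(SetUtils.intersectionP(a, b)); // true: they share 3
        }
    }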
| 
 |  | ||||||
| 
 |  | ||||||
| @ -1,633 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.vptree; |  | ||||||
| 
 |  | ||||||
| import lombok.*; |  | ||||||
| import lombok.extern.slf4j.Slf4j; |  | ||||||
| import org.deeplearning4j.clustering.sptree.DataPoint; |  | ||||||
| import org.deeplearning4j.clustering.sptree.HeapObject; |  | ||||||
| import org.deeplearning4j.clustering.util.MathUtils; |  | ||||||
| import org.nd4j.linalg.api.memory.MemoryWorkspace; |  | ||||||
| import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration; |  | ||||||
| import org.nd4j.linalg.api.memory.enums.*; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.api.ops.impl.reduce3.*; |  | ||||||
| import org.nd4j.linalg.exception.ND4JIllegalStateException; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| 
 |  | ||||||
| import java.io.Serializable; |  | ||||||
| import java.util.*; |  | ||||||
| import java.util.concurrent.*; |  | ||||||
| import java.util.concurrent.atomic.AtomicInteger; |  | ||||||
| 
 |  | ||||||
| @Slf4j |  | ||||||
| @Builder |  | ||||||
| @AllArgsConstructor |  | ||||||
| public class VPTree implements Serializable { |  | ||||||
|     private static final long serialVersionUID = 1L; |  | ||||||
| 
 |  | ||||||
|     public static final String EUCLIDEAN = "euclidean"; |  | ||||||
|     private double tau; |  | ||||||
|     @Getter |  | ||||||
|     @Setter |  | ||||||
|     private INDArray items; |  | ||||||
|     private List<INDArray> itemsList; |  | ||||||
|     private Node root; |  | ||||||
|     private String similarityFunction; |  | ||||||
|     @Getter |  | ||||||
|     private boolean invert = false; |  | ||||||
|     private transient ExecutorService executorService; |  | ||||||
|     @Getter |  | ||||||
|     private int workers = 1; |  | ||||||
|     private AtomicInteger size = new AtomicInteger(0); |  | ||||||
| 
 |  | ||||||
|     private transient ThreadLocal<INDArray> scalars = new ThreadLocal<>(); |  | ||||||
| 
 |  | ||||||
|     private WorkspaceConfiguration workspaceConfiguration; |  | ||||||
| 
 |  | ||||||
|     protected VPTree() { |  | ||||||
|         // method for serialization only |  | ||||||
|         scalars = new ThreadLocal<>(); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param points the points to index |  | ||||||
|      * @param invert whether to invert the distance (similarity functions have different min/max objectives) |  | ||||||
|      */ |  | ||||||
|     public VPTree(INDArray points, boolean invert) { |  | ||||||
|         this(points, "euclidean", 1, invert); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param points the points to index |  | ||||||
|      * @param invert whether to invert the distance |  | ||||||
|      * @param workers number of parallel workers for tree building (increases memory requirements!) |  | ||||||
|      */ |  | ||||||
|     public VPTree(INDArray points, boolean invert, int workers) { |  | ||||||
|         this(points, "euclidean", workers, invert); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * |  | ||||||
|      * @param items the items to use |  | ||||||
|      * @param similarityFunction the similarity function to use |  | ||||||
|      * @param invert whether to invert the distance (similarity functions have different min/max objectives) |  | ||||||
|      */ |  | ||||||
|     public VPTree(INDArray items, String similarityFunction, boolean invert) { |  | ||||||
|         this.similarityFunction = similarityFunction; |  | ||||||
|         this.invert = invert; |  | ||||||
|         this.items = items; |  | ||||||
|         root = buildFromPoints(items); |  | ||||||
|         workers = 1; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * |  | ||||||
|      * @param items the items to use |  | ||||||
|      * @param similarityFunction the similarity function to use |  | ||||||
|      * @param workers number of parallel workers for tree building (increases memory requirements!) |  | ||||||
|      * @param invert whether to invert the metric (different optimization objective) |  | ||||||
|      */ |  | ||||||
|     public VPTree(List<DataPoint> items, String similarityFunction, int workers, boolean invert) { |  | ||||||
|         this.workers = workers; |  | ||||||
| 
 |  | ||||||
|         val list = new INDArray[items.size()]; |  | ||||||
| 
 |  | ||||||
|         // build list of INDArrays first |  | ||||||
|         for (int i = 0; i < items.size(); i++) |  | ||||||
|             list[i] = items.get(i).getPoint(); |  | ||||||
|             //this.items.putRow(i, items.get(i).getPoint()); |  | ||||||
| 
 |  | ||||||
|         // just stack them out with concat :) |  | ||||||
|         this.items = Nd4j.pile(list); |  | ||||||
| 
 |  | ||||||
|         this.invert = invert; |  | ||||||
|         this.similarityFunction = similarityFunction; |  | ||||||
|         root = buildFromPoints(this.items); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param items the items to use |  | ||||||
|      * @param similarityFunction the similarity function to use |  | ||||||
|      */ |  | ||||||
|     public VPTree(INDArray items, String similarityFunction) { |  | ||||||
|         this(items, similarityFunction, 1, false); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param items the items to use |  | ||||||
|      * @param similarityFunction the similarity function to use |  | ||||||
|      * @param workers number of parallel workers for tree building (increases memory requirements!) |  | ||||||
|      * @param invert whether to invert the distance |  | ||||||
|      */ |  | ||||||
|     public VPTree(INDArray items, String similarityFunction, int workers, boolean invert) { |  | ||||||
|         this.similarityFunction = similarityFunction; |  | ||||||
|         this.invert = invert; |  | ||||||
|         this.items = items; |  | ||||||
| 
 |  | ||||||
|         this.workers = workers; |  | ||||||
|         root = buildFromPoints(items); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param items the items to use |  | ||||||
|      * @param similarityFunction the similarity function to use |  | ||||||
|      */ |  | ||||||
|     public VPTree(List<DataPoint> items, String similarityFunction) { |  | ||||||
|         this(items, similarityFunction, 1, false); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param items the items to index (Euclidean distance) |  | ||||||
|      */ |  | ||||||
|     public VPTree(INDArray items) { |  | ||||||
|         this(items, EUCLIDEAN); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * @param items the items to index (Euclidean distance) |  | ||||||
|      */ |  | ||||||
|     public VPTree(List<DataPoint> items) { |  | ||||||
|         this(items, EUCLIDEAN); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Create an ndarray from the datapoints |  | ||||||
|      * @param data the data points, one per slice |  | ||||||
|      * @return the stacked ndarray |  | ||||||
|      */ |  | ||||||
|     public static INDArray buildFromData(List<DataPoint> data) { |  | ||||||
|         INDArray ret = Nd4j.create(data.size(), data.get(0).getD()); |  | ||||||
|         for (int i = 0; i < ret.slices(); i++) |  | ||||||
|             ret.putSlice(i, data.get(i).getPoint()); |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Compute the distance from every row of items to basePoint |  | ||||||
|      * @param items the matrix of points, one per row |  | ||||||
|      * @param basePoint the reference point |  | ||||||
|      * @param distancesArr output array receiving one distance per row |  | ||||||
|      */ |  | ||||||
|     public void calcDistancesRelativeTo(INDArray items, INDArray basePoint, INDArray distancesArr) { |  | ||||||
|         switch (similarityFunction) { |  | ||||||
|             case "euclidean": |  | ||||||
|                 Nd4j.getExecutioner().exec(new EuclideanDistance(items, basePoint, distancesArr, true,-1)); |  | ||||||
|                 break; |  | ||||||
|             case "cosinedistance": |  | ||||||
|                 Nd4j.getExecutioner().exec(new CosineDistance(items, basePoint, distancesArr, true, -1)); |  | ||||||
|                 break; |  | ||||||
|             case "cosinesimilarity": |  | ||||||
|                 Nd4j.getExecutioner().exec(new CosineSimilarity(items, basePoint, distancesArr, true, -1)); |  | ||||||
|                 break; |  | ||||||
|             case "manhattan": |  | ||||||
|                 Nd4j.getExecutioner().exec(new ManhattanDistance(items, basePoint, distancesArr, true, -1)); |  | ||||||
|                 break; |  | ||||||
|             case "dot": |  | ||||||
|                 Nd4j.getExecutioner().exec(new Dot(items, basePoint, distancesArr, -1)); |  | ||||||
|                 break; |  | ||||||
|             case "jaccard": |  | ||||||
|                 Nd4j.getExecutioner().exec(new JaccardDistance(items, basePoint, distancesArr, true, -1)); |  | ||||||
|                 break; |  | ||||||
|             case "hamming": |  | ||||||
|                 Nd4j.getExecutioner().exec(new HammingDistance(items, basePoint, distancesArr, true, -1)); |  | ||||||
|                 break; |  | ||||||
|             default: |  | ||||||
|                 Nd4j.getExecutioner().exec(new EuclideanDistance(items, basePoint, distancesArr, true, -1)); |  | ||||||
|                 break; |  | ||||||
| 
 |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         if (invert) |  | ||||||
|             distancesArr.negi(); |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void calcDistancesRelativeTo(INDArray basePoint, INDArray distancesArr) { |  | ||||||
|         calcDistancesRelativeTo(items, basePoint, distancesArr); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     /** |  | ||||||
|      * Distance between two points under the configured similarity function, |  | ||||||
|      * negated when invert is set |  | ||||||
|      * @return the distance between the two points |  | ||||||
|      */ |  | ||||||
|     public double distance(INDArray arr1, INDArray arr2) { |  | ||||||
|         if (scalars == null) |  | ||||||
|             scalars = new ThreadLocal<>(); |  | ||||||
| 
 |  | ||||||
|         if (scalars.get() == null) |  | ||||||
|             scalars.set(Nd4j.scalar(arr1.dataType(), 0.0)); |  | ||||||
| 
 |  | ||||||
|         switch (similarityFunction) { |  | ||||||
|             case "jaccard": |  | ||||||
|                 double ret7 = Nd4j.getExecutioner() |  | ||||||
|                         .execAndReturn(new JaccardDistance(arr1, arr2, scalars.get())) |  | ||||||
|                         .getFinalResult().doubleValue(); |  | ||||||
|                 return invert ? -ret7 : ret7; |  | ||||||
|             case "hamming": |  | ||||||
|                 double ret8 = Nd4j.getExecutioner() |  | ||||||
|                         .execAndReturn(new HammingDistance(arr1, arr2, scalars.get())) |  | ||||||
|                         .getFinalResult().doubleValue(); |  | ||||||
|                 return invert ? -ret8 : ret8; |  | ||||||
|             case "euclidean": |  | ||||||
|                 double ret = Nd4j.getExecutioner() |  | ||||||
|                         .execAndReturn(new EuclideanDistance(arr1, arr2, scalars.get())) |  | ||||||
|                         .getFinalResult().doubleValue(); |  | ||||||
|                 return invert ? -ret : ret; |  | ||||||
|             case "cosinesimilarity": |  | ||||||
|                 double ret2 = Nd4j.getExecutioner() |  | ||||||
|                         .execAndReturn(new CosineSimilarity(arr1, arr2, scalars.get())) |  | ||||||
|                         .getFinalResult().doubleValue(); |  | ||||||
|                 return invert ? -ret2 : ret2; |  | ||||||
|             case "cosinedistance": |  | ||||||
|                 double ret6 = Nd4j.getExecutioner() |  | ||||||
|                         .execAndReturn(new CosineDistance(arr1, arr2, scalars.get())) |  | ||||||
|                         .getFinalResult().doubleValue(); |  | ||||||
|                 return invert ? -ret6 : ret6; |  | ||||||
|             case "manhattan": |  | ||||||
|                 double ret3 = Nd4j.getExecutioner() |  | ||||||
|                         .execAndReturn(new ManhattanDistance(arr1, arr2, scalars.get())) |  | ||||||
|                         .getFinalResult().doubleValue(); |  | ||||||
|                 return invert ? -ret3 : ret3; |  | ||||||
|             case "dot": |  | ||||||
|                 double dotRet = Nd4j.getBlasWrapper().dot(arr1, arr2); |  | ||||||
|                 return invert ? -dotRet : dotRet; |  | ||||||
|             default: |  | ||||||
|                 double ret4 = Nd4j.getExecutioner() |  | ||||||
|                         .execAndReturn(new EuclideanDistance(arr1, arr2, scalars.get())) |  | ||||||
|                         .getFinalResult().doubleValue(); |  | ||||||
|                 return invert ? -ret4 : ret4; |  | ||||||
| 
 |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
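Note that "cosinesimilarity" and "dot" return similarities (larger means closer), so the {@code invert} negation is what restores a smaller-is-closer ordering. A quick illustration using ND4J's Transforms.cosineSim helper (import org.nd4j.linalg.ops.transforms.Transforms):

    INDArray a = Nd4j.createFromArray(new float[]{1f, 0f});
    INDArray b = Nd4j.createFromArray(new float[]{0f, 1f});
    double sim = Transforms.cosineSim(a, b);  // 0.0 for orthogonal vectors
    double asDistance = -sim;                 // negation maps identical vectors (sim 1.0) to the minimum, -1.0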
| 
 |  | ||||||
|     protected class NodeBuilder implements Callable<Node> { |  | ||||||
|         protected List<INDArray> list; |  | ||||||
|         protected List<Integer> indices; |  | ||||||
| 
 |  | ||||||
|         public NodeBuilder(List<INDArray> list, List<Integer> indices) { |  | ||||||
|             this.list = list; |  | ||||||
|             this.indices = indices; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         @Override |  | ||||||
|         public Node call() throws Exception { |  | ||||||
|             return buildFromPoints(list, indices); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private Node buildFromPoints(List<INDArray> points, List<Integer> indices) { |  | ||||||
|         Node ret = new Node(0, 0); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         // nothing to sort here |  | ||||||
|         if (points.size() == 1) { |  | ||||||
|             ret.point = points.get(0); |  | ||||||
|             ret.index = indices.get(0); |  | ||||||
|             return ret; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         INDArray items = Nd4j.vstack(points); |  | ||||||
|         int randomPoint = MathUtils.randomNumberBetween(0, items.rows() - 1, Nd4j.getRandom()); |  | ||||||
|         INDArray basePoint = points.get(randomPoint); |  | ||||||
|         ret.point = basePoint; |  | ||||||
|         ret.index = indices.get(randomPoint); |  | ||||||
|         INDArray distancesArr = Nd4j.create(items.rows(), 1); |  | ||||||
| 
 |  | ||||||
|         calcDistancesRelativeTo(items, basePoint, distancesArr); |  | ||||||
| 
 |  | ||||||
|         double medianDistance = distancesArr.medianNumber().doubleValue(); |  | ||||||
| 
 |  | ||||||
|         ret.threshold = (float) medianDistance; |  | ||||||
| 
 |  | ||||||
|         List<INDArray> leftPoints = new ArrayList<>(); |  | ||||||
|         List<Integer> leftIndices = new ArrayList<>(); |  | ||||||
|         List<INDArray> rightPoints = new ArrayList<>(); |  | ||||||
|         List<Integer> rightIndices = new ArrayList<>(); |  | ||||||
| 
 |  | ||||||
|         for (int i = 0; i < distancesArr.length(); i++) { |  | ||||||
|             if (i == randomPoint) |  | ||||||
|                 continue; |  | ||||||
| 
 |  | ||||||
|             if (distancesArr.getDouble(i) < medianDistance) { |  | ||||||
|                 leftPoints.add(points.get(i)); |  | ||||||
|                 leftIndices.add(indices.get(i)); |  | ||||||
|             } else { |  | ||||||
|                 rightPoints.add(points.get(i)); |  | ||||||
|                 rightIndices.add(indices.get(i)); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         if (workers > 1) { |  | ||||||
|             if (!leftPoints.isEmpty()) |  | ||||||
|                 ret.futureLeft = executorService.submit(new NodeBuilder(leftPoints, leftIndices)); |  | ||||||
| 
 |  | ||||||
|             if (!rightPoints.isEmpty()) |  | ||||||
|                 ret.futureRight = executorService.submit(new NodeBuilder(rightPoints, rightIndices)); |  | ||||||
|         } else { |  | ||||||
|             if (!leftPoints.isEmpty()) |  | ||||||
|                 ret.left = buildFromPoints(leftPoints, leftIndices); |  | ||||||
| 
 |  | ||||||
|             if (!rightPoints.isEmpty()) |  | ||||||
|                 ret.right = buildFromPoints(rightPoints, rightIndices); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
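For intuition, the split above is the standard vantage-point partition: points closer to the base point than the median distance go to the left child, the rest to the right. A tiny worked sketch of that median split:

    INDArray d = Nd4j.createFromArray(new double[]{0.3, 1.2, 0.7, 2.0}).reshape(4, 1);
    double median = d.medianNumber().doubleValue();  // (0.7 + 1.2) / 2 = 0.95
    // rows with distance < 0.95 (indices 0 and 2) go left, the rest go right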
| 
 |  | ||||||
|     private Node buildFromPoints(INDArray items) { |  | ||||||
|         if (executorService == null && items == this.items && workers > 1) { |  | ||||||
|             final val deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread(); |  | ||||||
| 
 |  | ||||||
|             executorService = Executors.newFixedThreadPool(workers, new ThreadFactory() { |  | ||||||
|                 @Override |  | ||||||
|                 public Thread newThread(final Runnable r) { |  | ||||||
|                     Thread t = new Thread(new Runnable() { |  | ||||||
| 
 |  | ||||||
|                         @Override |  | ||||||
|                         public void run() { |  | ||||||
|                             Nd4j.getAffinityManager().unsafeSetDevice(deviceId); |  | ||||||
|                             r.run(); |  | ||||||
|                         } |  | ||||||
|                     }); |  | ||||||
| 
 |  | ||||||
|                     t.setDaemon(true); |  | ||||||
|                     t.setName("VPTree thread"); |  | ||||||
| 
 |  | ||||||
|                     return t; |  | ||||||
|                 } |  | ||||||
|             }); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         final Node ret = new Node(0, 0); |  | ||||||
|         size.incrementAndGet(); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         int randomPoint = MathUtils.randomNumberBetween(0, items.rows() - 1, Nd4j.getRandom()); |  | ||||||
|         INDArray basePoint = items.getRow(randomPoint, true); |  | ||||||
|         INDArray distancesArr = Nd4j.create(items.rows(), 1); |  | ||||||
|         ret.point = basePoint; |  | ||||||
|         ret.index = randomPoint; |  | ||||||
| 
 |  | ||||||
|         calcDistancesRelativeTo(items, basePoint, distancesArr); |  | ||||||
| 
 |  | ||||||
|         double medianDistance = distancesArr.medianNumber().doubleValue(); |  | ||||||
| 
 |  | ||||||
|         ret.threshold = (float) medianDistance; |  | ||||||
| 
 |  | ||||||
|         List<INDArray> leftPoints = new ArrayList<>(); |  | ||||||
|         List<Integer> leftIndices = new ArrayList<>(); |  | ||||||
|         List<INDArray> rightPoints = new ArrayList<>(); |  | ||||||
|         List<Integer> rightIndices = new ArrayList<>(); |  | ||||||
| 
 |  | ||||||
|         for (int i = 0; i < distancesArr.length(); i++) { |  | ||||||
|             if (i == randomPoint) |  | ||||||
|                 continue; |  | ||||||
| 
 |  | ||||||
|             if (distancesArr.getDouble(i) < medianDistance) { |  | ||||||
|                 leftPoints.add(items.getRow(i, true)); |  | ||||||
|                 leftIndices.add(i); |  | ||||||
|             } else { |  | ||||||
|                 rightPoints.add(items.getRow(i, true)); |  | ||||||
|                 rightIndices.add(i); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         if (!leftPoints.isEmpty()) |  | ||||||
|             ret.left = buildFromPoints(leftPoints, leftIndices); |  | ||||||
| 
 |  | ||||||
|         if (!rightPoints.isEmpty()) |  | ||||||
|             ret.right = buildFromPoints(rightPoints, rightIndices); |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         if (ret.left != null) |  | ||||||
|             ret.left.fetchFutures(); |  | ||||||
| 
 |  | ||||||
|         if (ret.right != null) |  | ||||||
|             ret.right.fetchFutures(); |  | ||||||
| 
 |  | ||||||
|         if (executorService != null) |  | ||||||
|             executorService.shutdown(); |  | ||||||
| 
 |  | ||||||
|         return ret; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void search(@NonNull INDArray target, int k, List<DataPoint> results, List<Double> distances) { |  | ||||||
|         search(target, k, results, distances, true); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void search(@NonNull INDArray target, int k, List<DataPoint> results, List<Double> distances, |  | ||||||
|                        boolean filterEqual) { |  | ||||||
|         search(target, k, results, distances, filterEqual, false); |  | ||||||
|     } |  | ||||||
|     /** |  | ||||||
|      * Finds the {@code k} nearest neighbours of {@code target}. |  | ||||||
|      * @param target the query vector |  | ||||||
|      * @param k the number of neighbours to return |  | ||||||
|      * @param results output list receiving the neighbouring points, nearest first |  | ||||||
|      * @param distances output list receiving the corresponding distances, ascending |  | ||||||
|      * @param filterEqual if true, drop an exact (zero-distance) duplicate of the query |  | ||||||
|      * @param dropEdge if true, apply the zero-distance filter and trimming even when |  | ||||||
|      *                 no more than {@code k} results were found |  | ||||||
|      */ |  | ||||||
|     public void search(@NonNull INDArray target, int k, List<DataPoint> results, List<Double> distances, |  | ||||||
|                        boolean filterEqual, boolean dropEdge) { |  | ||||||
|         if (items != null) |  | ||||||
|             if (!target.isVectorOrScalar() || target.columns() != items.columns() || target.rows() > 1) |  | ||||||
|                 throw new ND4JIllegalStateException("Target for search should have shape of [" + 1 + ", " |  | ||||||
|                         + items.columns() + "] but got " + Arrays.toString(target.shape()) + " instead"); |  | ||||||
| 
 |  | ||||||
|         k = Math.min(k, items.rows()); |  | ||||||
|         results.clear(); |  | ||||||
|         distances.clear(); |  | ||||||
| 
 |  | ||||||
|         PriorityQueue<HeapObject> pq = new PriorityQueue<>(items.rows(), new HeapObjectComparator()); |  | ||||||
| 
 |  | ||||||
|         search(root, target, k + (filterEqual ? 2 : 1), pq, Double.MAX_VALUE); |  | ||||||
| 
 |  | ||||||
|         while (!pq.isEmpty()) { |  | ||||||
|             HeapObject ho = pq.peek(); |  | ||||||
|             results.add(new DataPoint(ho.getIndex(), ho.getPoint())); |  | ||||||
|             distances.add(ho.getDistance()); |  | ||||||
|             pq.poll(); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         Collections.reverse(results); |  | ||||||
|         Collections.reverse(distances); |  | ||||||
| 
 |  | ||||||
|         if (dropEdge || results.size() > k) { |  | ||||||
|             if (filterEqual && distances.get(0) == 0.0) { |  | ||||||
|                 results.remove(0); |  | ||||||
|                 distances.remove(0); |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             while (results.size() > k) { |  | ||||||
|                 results.remove(results.size() - 1); |  | ||||||
|                 distances.remove(distances.size() - 1); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
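A minimal end-to-end usage of this entry point; the three-argument VPTree constructor is again an assumption, not shown in this diff:

    INDArray data = Nd4j.rand(DataType.FLOAT, 1000, 16);
    VPTree tree = new VPTree(data, "euclidean", false);        // assumed constructor
    List<DataPoint> results = new ArrayList<>();
    List<Double> distances = new ArrayList<>();
    tree.search(data.getRow(0, true), 5, results, distances);  // 5 nearest neighbours of row 0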
| 
 |  | ||||||
|     /** |  | ||||||
|      * Recursively searches the subtree rooted at {@code node}. |  | ||||||
|      * @param node the subtree root (may be null) |  | ||||||
|      * @param target the query vector |  | ||||||
|      * @param k the maximum number of candidates kept in the heap |  | ||||||
|      * @param pq max-heap of the best candidates found so far, ordered by distance |  | ||||||
|      * @param cTau the current search radius, i.e. the distance of the worst kept candidate |  | ||||||
|      */ |  | ||||||
|     public void search(Node node, INDArray target, int k, PriorityQueue<HeapObject> pq, double cTau) { |  | ||||||
| 
 |  | ||||||
|         if (node == null) |  | ||||||
|             return; |  | ||||||
| 
 |  | ||||||
|         double tau = cTau; |  | ||||||
| 
 |  | ||||||
|         INDArray nodePoint = node.getPoint(); |  | ||||||
|         double distance = distance(nodePoint, target); |  | ||||||
|         if (distance < tau) { |  | ||||||
|             if (pq.size() == k) |  | ||||||
|                 pq.poll(); |  | ||||||
| 
 |  | ||||||
|             pq.add(new HeapObject(node.getIndex(), node.getPoint(), distance)); |  | ||||||
|             if (pq.size() == k) |  | ||||||
|                 tau = pq.peek().getDistance(); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         Node left = node.getLeft(); |  | ||||||
|         Node right = node.getRight(); |  | ||||||
| 
 |  | ||||||
|         if (left == null && right == null) |  | ||||||
|             return; |  | ||||||
| 
 |  | ||||||
|         if (distance < node.getThreshold()) { |  | ||||||
|             if (distance - tau < node.getThreshold()) { // if there can still be neighbors inside the ball, recursively search left child first |  | ||||||
|                 search(left, target, k, pq, tau); |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             if (distance + tau >= node.getThreshold()) { // if there can still be neighbors outside the ball, recursively search right child |  | ||||||
|                 search(right, target, k, pq, tau); |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|         } else { |  | ||||||
|             if (distance + tau >= node.getThreshold()) { // if there can still be neighbors outside the ball, recursively search right child first |  | ||||||
|                 search(right, target, k, pq, tau); |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
|             if (distance - tau < node.getThreshold()) { // if there can still be neighbors inside the ball, recursively search left child |  | ||||||
|                 search(left, target, k, pq, tau); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
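The two guards in each branch are the classic triangle-inequality pruning tests: a subtree is visited only while the ball of radius tau around the query can overlap its region. As a self-contained sketch of that predicate (hypothetical helper, not part of this class):

    // true if the inside (resp. outside) region of the node's ball may still
    // contain a point within tau of the query
    static boolean mayContain(double distToNode, double threshold, double tau, boolean inside) {
        return inside ? distToNode - tau < threshold
                      : distToNode + tau >= threshold;
    }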
| 
 |  | ||||||
| 
 |  | ||||||
|     protected class HeapObjectComparator implements Comparator<HeapObject> { |  | ||||||
| 
 |  | ||||||
|         @Override |  | ||||||
|         public int compare(HeapObject o1, HeapObject o2) { |  | ||||||
|             return Double.compare(o2.getDistance(), o1.getDistance()); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Data |  | ||||||
|     public static class Node implements Serializable { |  | ||||||
|         private static final long serialVersionUID = 2L; |  | ||||||
| 
 |  | ||||||
|         private int index; |  | ||||||
|         private float threshold; |  | ||||||
|         private Node left, right; |  | ||||||
|         private INDArray point; |  | ||||||
|         protected transient Future<Node> futureLeft; |  | ||||||
|         protected transient Future<Node> futureRight; |  | ||||||
| 
 |  | ||||||
|         public Node(int index, float threshold) { |  | ||||||
|             this.index = index; |  | ||||||
|             this.threshold = threshold; |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         public void fetchFutures() { |  | ||||||
|             try { |  | ||||||
|                 if (futureLeft != null) { |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|                     left = futureLeft.get(); |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
|                 if (futureRight != null) { |  | ||||||
| 
 |  | ||||||
|                     right = futureRight.get(); |  | ||||||
|                 } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|                 if (left != null) |  | ||||||
|                     left.fetchFutures(); |  | ||||||
| 
 |  | ||||||
|                 if (right != null) |  | ||||||
|                     right.fetchFutures(); |  | ||||||
|             } catch (Exception e) { |  | ||||||
|                 throw new RuntimeException(e); |  | ||||||
|             } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
| @ -1,79 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.vptree; |  | ||||||
| 
 |  | ||||||
| import lombok.Getter; |  | ||||||
| import org.deeplearning4j.clustering.sptree.DataPoint; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| 
 |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| public class VPTreeFillSearch { |  | ||||||
|     private VPTree vpTree; |  | ||||||
|     private int k; |  | ||||||
|     @Getter |  | ||||||
|     private List<DataPoint> results; |  | ||||||
|     @Getter |  | ||||||
|     private List<Double> distances; |  | ||||||
|     private INDArray target; |  | ||||||
| 
 |  | ||||||
|     public VPTreeFillSearch(VPTree vpTree, int k, INDArray target) { |  | ||||||
|         this.vpTree = vpTree; |  | ||||||
|         this.k = k; |  | ||||||
|         this.target = target; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     public void search() { |  | ||||||
|         results = new ArrayList<>(); |  | ||||||
|         distances = new ArrayList<>(); |  | ||||||
|         // Brute-force fill: rank every stored item by distance to the target, |  | ||||||
|         // then take the k closest so that exactly k results are always returned. |  | ||||||
|         INDArray distancesArr = Nd4j.create(vpTree.getItems().rows(), 1); |  | ||||||
|         vpTree.calcDistancesRelativeTo(target, distancesArr); |  | ||||||
|         INDArray[] sortWithIndices = Nd4j.sortWithIndices(distancesArr, 0, !vpTree.isInvert()); |  | ||||||
|         results.clear(); |  | ||||||
|         distances.clear(); |  | ||||||
|         if (vpTree.getItems().isVector()) { |  | ||||||
|             for (int i = 0; i < k; i++) { |  | ||||||
|                 int idx = sortWithIndices[0].getInt(i); |  | ||||||
|                 results.add(new DataPoint(idx, Nd4j.scalar(vpTree.getItems().getDouble(idx)))); |  | ||||||
|                 distances.add(sortWithIndices[1].getDouble(i)); // sorted values are positional, indexed by i |  | ||||||
|             } |  | ||||||
|         } else { |  | ||||||
|             for (int i = 0; i < k; i++) { |  | ||||||
|                 int idx = sortWithIndices[0].getInt(i); |  | ||||||
|                 results.add(new DataPoint(idx, vpTree.getItems().getRow(idx))); |  | ||||||
|                 distances.add(sortWithIndices[1].getDouble(i)); |  | ||||||
|             } |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| } |  | ||||||
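A hedged usage sketch of the fill search, again assuming the three-argument VPTree constructor:

    VPTree tree = new VPTree(Nd4j.rand(DataType.FLOAT, 100, 4), "euclidean", false);
    VPTreeFillSearch fill = new VPTreeFillSearch(tree, 3, Nd4j.rand(DataType.FLOAT, 1, 4));
    fill.search();
    List<DataPoint> nn = fill.getResults();   // exactly k = 3 points (given at least 3 items)
    List<Double> dist = fill.getDistances();  // nearest first, unless the tree is inverted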
| @ -1,21 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.vptree; |  | ||||||
| @ -1,46 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.cluster; |  | ||||||
| 
 |  | ||||||
| import org.junit.Assert; |  | ||||||
| import org.junit.Test; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| 
 |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.List; |  | ||||||
| 
 |  | ||||||
| public class ClusterSetTest { |  | ||||||
|     @Test |  | ||||||
|     public void testGetMostPopulatedClusters() { |  | ||||||
|         ClusterSet clusterSet = new ClusterSet(false); |  | ||||||
|         List<Cluster> clusters = new ArrayList<>(); |  | ||||||
|         for (int i = 0; i < 5; i++) { |  | ||||||
|             Cluster cluster = new Cluster(); |  | ||||||
|             cluster.setPoints(Point.toPoints(Nd4j.randn(i + 1, 5))); |  | ||||||
|             clusters.add(cluster); |  | ||||||
|         } |  | ||||||
|         clusterSet.setClusters(clusters); |  | ||||||
|         List<Cluster> mostPopulatedClusters = clusterSet.getMostPopulatedClusters(5); |  | ||||||
|         for (int i = 0; i < 5; i++) { |  | ||||||
|             Assert.assertEquals(5 - i, mostPopulatedClusters.get(i).getPoints().size()); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
| @ -1,422 +0,0 @@ | |||||||
| /* |  | ||||||
|  *  ****************************************************************************** |  | ||||||
|  *  * |  | ||||||
|  *  * |  | ||||||
|  *  * This program and the accompanying materials are made available under the |  | ||||||
|  *  * terms of the Apache License, Version 2.0 which is available at |  | ||||||
|  *  * https://www.apache.org/licenses/LICENSE-2.0. |  | ||||||
|  *  * |  | ||||||
|  *  *  See the NOTICE file distributed with this work for additional |  | ||||||
|  *  *  information regarding copyright ownership. |  | ||||||
|  *  * Unless required by applicable law or agreed to in writing, software |  | ||||||
|  *  * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |  | ||||||
|  *  * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |  | ||||||
|  *  * License for the specific language governing permissions and limitations |  | ||||||
|  *  * under the License. |  | ||||||
|  *  * |  | ||||||
|  *  * SPDX-License-Identifier: Apache-2.0 |  | ||||||
|  *  ***************************************************************************** |  | ||||||
|  */ |  | ||||||
| 
 |  | ||||||
| package org.deeplearning4j.clustering.kdtree; |  | ||||||
| 
 |  | ||||||
| import lombok.val; |  | ||||||
| import org.deeplearning4j.BaseDL4JTest; |  | ||||||
| import org.joda.time.Duration; |  | ||||||
| import org.junit.Before; |  | ||||||
| import org.junit.BeforeClass; |  | ||||||
| import org.junit.Ignore; |  | ||||||
| import org.junit.Test; |  | ||||||
| import org.nd4j.linalg.api.buffer.DataType; |  | ||||||
| import org.nd4j.linalg.api.ndarray.INDArray; |  | ||||||
| import org.nd4j.linalg.factory.Nd4j; |  | ||||||
| import org.nd4j.common.primitives.Pair; |  | ||||||
| import org.nd4j.shade.guava.base.Stopwatch; |  | ||||||
| import org.nd4j.shade.guava.primitives.Doubles; |  | ||||||
| import org.nd4j.shade.guava.primitives.Floats; |  | ||||||
| 
 |  | ||||||
| import java.util.ArrayList; |  | ||||||
| import java.util.Arrays; |  | ||||||
| import java.util.List; |  | ||||||
| import java.util.Random; |  | ||||||
| 
 |  | ||||||
| import static java.util.concurrent.TimeUnit.MILLISECONDS; |  | ||||||
| import static java.util.concurrent.TimeUnit.SECONDS; |  | ||||||
| import static org.junit.Assert.assertEquals; |  | ||||||
| import static org.junit.Assert.assertTrue; |  | ||||||
| 
 |  | ||||||
| public class KDTreeTest extends BaseDL4JTest { |  | ||||||
| 
 |  | ||||||
|     @Override |  | ||||||
|     public long getTimeoutMilliseconds() { |  | ||||||
|         return 120000L; |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     private KDTree kdTree; |  | ||||||
| 
 |  | ||||||
|     @BeforeClass |  | ||||||
|     public static void beforeClass(){ |  | ||||||
|         Nd4j.setDataType(DataType.FLOAT); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Before |  | ||||||
|     public void setUp() { |  | ||||||
|         kdTree = new KDTree(2); |  | ||||||
|         float[] data = new float[]{7,2}; |  | ||||||
|         kdTree.insert(Nd4j.createFromArray(data)); |  | ||||||
|         data = new float[]{5,4}; |  | ||||||
|         kdTree.insert(Nd4j.createFromArray(data)); |  | ||||||
|         data = new float[]{2,3}; |  | ||||||
|         kdTree.insert(Nd4j.createFromArray(data)); |  | ||||||
|         data = new float[]{4,7}; |  | ||||||
|         kdTree.insert(Nd4j.createFromArray(data)); |  | ||||||
|         data = new float[]{9,6}; |  | ||||||
|         kdTree.insert(Nd4j.createFromArray(data)); |  | ||||||
|         data = new float[]{8,1}; |  | ||||||
|         kdTree.insert(Nd4j.createFromArray(data)); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testTree() { |  | ||||||
|         KDTree tree = new KDTree(2); |  | ||||||
|         INDArray half = Nd4j.create(new double[] {0.5, 0.5}, new long[]{1,2}).castTo(DataType.FLOAT); |  | ||||||
|         INDArray one = Nd4j.create(new double[] {1, 1}, new long[]{1,2}).castTo(DataType.FLOAT); |  | ||||||
|         tree.insert(half); |  | ||||||
|         tree.insert(one); |  | ||||||
|         Pair<Double, INDArray> pair = tree.nn(Nd4j.create(new double[] {0.5, 0.5}, new long[]{1,2}).castTo(DataType.FLOAT)); |  | ||||||
|         assertEquals(half, pair.getValue()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testInsert() { |  | ||||||
|         int elements = 10; |  | ||||||
|         List<Double> digits = Arrays.asList(1.0, 0.0, 2.0, 3.0); |  | ||||||
| 
 |  | ||||||
|         KDTree kdTree = new KDTree(digits.size()); |  | ||||||
|         List<List<Double>> lists = new ArrayList<>(); |  | ||||||
|         for (int i = 0; i < elements; i++) { |  | ||||||
|             List<Double> thisList = new ArrayList<>(digits.size()); |  | ||||||
|             for (int k = 0; k < digits.size(); k++) { |  | ||||||
|                 thisList.add(digits.get(k) + i); |  | ||||||
|             } |  | ||||||
|             lists.add(thisList); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         for (int i = 0; i < elements; i++) { |  | ||||||
|             double[] features = Doubles.toArray(lists.get(i)); |  | ||||||
|             INDArray ind = Nd4j.create(features, new long[]{1, features.length}, DataType.FLOAT); |  | ||||||
|             kdTree.insert(ind); |  | ||||||
|             assertEquals(i + 1, kdTree.size()); |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testDelete() { |  | ||||||
|         int elements = 10; |  | ||||||
|         List<Double> digits = Arrays.asList(1.0, 0.0, 2.0, 3.0); |  | ||||||
| 
 |  | ||||||
|         KDTree kdTree = new KDTree(digits.size()); |  | ||||||
|         List<List<Double>> lists = new ArrayList<>(); |  | ||||||
|         for (int i = 0; i < elements; i++) { |  | ||||||
|             List<Double> thisList = new ArrayList<>(digits.size()); |  | ||||||
|             for (int k = 0; k < digits.size(); k++) { |  | ||||||
|                 thisList.add(digits.get(k) + i); |  | ||||||
|             } |  | ||||||
|             lists.add(thisList); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         INDArray toDelete = Nd4j.empty(DataType.DOUBLE), |  | ||||||
|                  leafToDelete = Nd4j.empty(DataType.DOUBLE); |  | ||||||
|         for (int i = 0; i < elements; i++) { |  | ||||||
|             double[] features = Doubles.toArray(lists.get(i)); |  | ||||||
|             INDArray ind = Nd4j.create(features, new long[]{1, features.length}, DataType.FLOAT); |  | ||||||
|             if (i == 1) |  | ||||||
|                 toDelete = ind; |  | ||||||
|             if (i == elements - 1) { |  | ||||||
|                 leafToDelete = ind; |  | ||||||
|             } |  | ||||||
|             kdTree.insert(ind); |  | ||||||
|             assertEquals(i + 1, kdTree.size()); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         kdTree.delete(toDelete); |  | ||||||
|         assertEquals(9, kdTree.size()); |  | ||||||
|         kdTree.delete(leafToDelete); |  | ||||||
|         assertEquals(8, kdTree.size()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testNN() { |  | ||||||
|         int n = 10; |  | ||||||
| 
 |  | ||||||
|         // make a KD-tree of dimension n |  | ||||||
|         KDTree kdTree = new KDTree(n); |  | ||||||
|         for (int i = -1; i < n; i++) { |  | ||||||
|             // Insert a unit vector along each dimension |  | ||||||
|             List<Double> vec = new ArrayList<>(n); |  | ||||||
|             // i = -1 ensures the origin is in the Tree |  | ||||||
|             for (int k = 0; k < n; k++) { |  | ||||||
|                 vec.add((k == i) ? 1.0 : 0.0); |  | ||||||
|             } |  | ||||||
|             INDArray indVec = Nd4j.create(Doubles.toArray(vec), new long[]{1, vec.size()}, DataType.FLOAT); |  | ||||||
|             kdTree.insert(indVec); |  | ||||||
|         } |  | ||||||
|         Random rand = new Random(); |  | ||||||
| 
 |  | ||||||
|         // random point in the Hypercube |  | ||||||
|         List<Double> pt = new ArrayList<>(n); |  | ||||||
|         for (int k = 0; k < n; k++) { |  | ||||||
|             pt.add(rand.nextDouble()); |  | ||||||
|         } |  | ||||||
|         Pair<Double, INDArray> result = kdTree.nn(Nd4j.create(Doubles.toArray(pt), new long[]{1, pt.size()}, DataType.FLOAT)); |  | ||||||
| 
 |  | ||||||
|         // Always true for points in the unitary hypercube |  | ||||||
|         assertTrue(result.getKey() < Double.MAX_VALUE); |  | ||||||
| 
 |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testKNN() { |  | ||||||
|         int dimensions = 512; |  | ||||||
|         int vectorsNo = isIntegrationTests() ? 50000 : 1000; |  | ||||||
|         // make a KD-tree of the given dimensionality |  | ||||||
|         Stopwatch stopwatch = Stopwatch.createStarted(); |  | ||||||
|         KDTree kdTree = new KDTree(dimensions); |  | ||||||
|         for (int i = -1; i < vectorsNo; i++) { |  | ||||||
|             // Insert random vectors |  | ||||||
|             INDArray indVec = Nd4j.rand(DataType.FLOAT, 1, dimensions); |  | ||||||
|             kdTree.insert(indVec); |  | ||||||
|         } |  | ||||||
|         stopwatch.stop(); |  | ||||||
|         System.out.println("Time elapsed for " + kdTree.size() + " nodes construction is "+ stopwatch.elapsed(SECONDS)); |  | ||||||
| 
 |  | ||||||
|         Random rand = new Random(); |  | ||||||
|         // random point in the Hypercube |  | ||||||
|         List<Double> pt = new ArrayList<>(dimensions); |  | ||||||
|         for (int k = 0; k < dimensions; k++) { |  | ||||||
|             pt.add(rand.nextFloat() * 10.0); |  | ||||||
|         } |  | ||||||
|         stopwatch.reset(); |  | ||||||
|         stopwatch.start(); |  | ||||||
|         List<Pair<Float, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Floats.toArray(pt))), 20.0f); |  | ||||||
|         stopwatch.stop(); |  | ||||||
|         System.out.println("Time elapsed for Search is "+ stopwatch.elapsed(MILLISECONDS)); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testKNN_Simple() { |  | ||||||
|         int n = 2; |  | ||||||
|         KDTree kdTree = new KDTree(n); |  | ||||||
| 
 |  | ||||||
|         float[] data = new float[]{3,3}; |  | ||||||
|         kdTree.insert(Nd4j.createFromArray(data)); |  | ||||||
|         data = new float[]{1,1}; |  | ||||||
|         kdTree.insert(Nd4j.createFromArray(data)); |  | ||||||
|         data = new float[]{2,2}; |  | ||||||
|         kdTree.insert(Nd4j.createFromArray(data)); |  | ||||||
| 
 |  | ||||||
|         data = new float[]{0,0}; |  | ||||||
|         List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 4.5f); |  | ||||||
| 
 |  | ||||||
|         assertEquals(1.0, result.get(0).getSecond().getDouble(0), 1e-5); |  | ||||||
|         assertEquals(1.0, result.get(0).getSecond().getDouble(1), 1e-5); |  | ||||||
| 
 |  | ||||||
|         assertEquals(2.0, result.get(1).getSecond().getDouble(0), 1e-5); |  | ||||||
|         assertEquals(2.0, result.get(1).getSecond().getDouble(1), 1e-5); |  | ||||||
| 
 |  | ||||||
|         assertEquals(3.0, result.get(2).getSecond().getDouble(0), 1e-5); |  | ||||||
|         assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5); |  | ||||||
|     } |  | ||||||
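The second argument of knn is a search radius, not a neighbour count: 4.5f above captures exactly the three inserted points, whose distances from the origin are √2 ≈ 1.41, √8 ≈ 2.83 and √18 ≈ 4.24, returned nearest first.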
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testKNN_1() { |  | ||||||
| 
 |  | ||||||
|         assertEquals(6, kdTree.size()); |  | ||||||
| 
 |  | ||||||
|         float[] data = new float[]{8,1}; |  | ||||||
|         List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f); |  | ||||||
|         assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5); |  | ||||||
|         assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5); |  | ||||||
|         assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5); |  | ||||||
|         assertEquals(9.0, result.get(3).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(6.0, result.get(3).getSecond().getFloat(1), 1e-5); |  | ||||||
|         assertEquals(2.0, result.get(4).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(3.0, result.get(4).getSecond().getFloat(1), 1e-5); |  | ||||||
|         assertEquals(4.0, result.get(5).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(7.0, result.get(5).getSecond().getFloat(1), 1e-5); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testKNN_2() { |  | ||||||
|         float[] data = new float[]{8, 1}; |  | ||||||
|         List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f); |  | ||||||
|         assertEquals(8.0, result.get(0).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(1.0, result.get(0).getSecond().getFloat(1), 1e-5); |  | ||||||
|         assertEquals(7.0, result.get(1).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(2.0, result.get(1).getSecond().getFloat(1), 1e-5); |  | ||||||
|         assertEquals(5.0, result.get(2).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(4.0, result.get(2).getSecond().getFloat(1), 1e-5); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testKNN_3() { |  | ||||||
| 
 |  | ||||||
|         float[] data = new float[]{2, 3}; |  | ||||||
|         List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f); |  | ||||||
|         assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5); |  | ||||||
|         assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5); |  | ||||||
|         assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5); |  | ||||||
|         assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5); |  | ||||||
|         assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5); |  | ||||||
|         assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testKNN_4() { |  | ||||||
|         float[] data = new float[]{2, 3}; |  | ||||||
|         List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f); |  | ||||||
|         assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5); |  | ||||||
|         assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5); |  | ||||||
|         assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testKNN_5() { |  | ||||||
|         float[] data = new float[]{2, 3}; |  | ||||||
|         List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f); |  | ||||||
|         assertEquals(2.0, result.get(0).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(3.0, result.get(0).getSecond().getFloat(1), 1e-5); |  | ||||||
|         assertEquals(5.0, result.get(1).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(4.0, result.get(1).getSecond().getFloat(1), 1e-5); |  | ||||||
|         assertEquals(4.0, result.get(2).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(7.0, result.get(2).getSecond().getFloat(1), 1e-5); |  | ||||||
|         assertEquals(7.0, result.get(3).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(2.0, result.get(3).getSecond().getFloat(1), 1e-5); |  | ||||||
|         assertEquals(8.0, result.get(4).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(1.0, result.get(4).getSecond().getFloat(1), 1e-5); |  | ||||||
|         assertEquals(9.0, result.get(5).getSecond().getFloat(0), 1e-5); |  | ||||||
|         assertEquals(6.0, result.get(5).getSecond().getFloat(1), 1e-5); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void test_KNN_6() { |  | ||||||
|         float[] data = new float[]{4, 6}; |  | ||||||
|         List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 10.0f); |  | ||||||
|         assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); |  | ||||||
|         assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); |  | ||||||
|         assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); |  | ||||||
|         assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); |  | ||||||
|         assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5); |  | ||||||
|         assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5); |  | ||||||
|         assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5); |  | ||||||
|         assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5); |  | ||||||
|         assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5); |  | ||||||
|         assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5); |  | ||||||
|         assertEquals(8.0, result.get(5).getSecond().getDouble(0), 1e-5); |  | ||||||
|         assertEquals(1.0, result.get(5).getSecond().getDouble(1), 1e-5); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void test_KNN_7() { |  | ||||||
|         float[] data = new float[]{4, 6}; |  | ||||||
|         List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 5.0f); |  | ||||||
|         assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); |  | ||||||
|         assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); |  | ||||||
|         assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); |  | ||||||
|         assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); |  | ||||||
|         assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5); |  | ||||||
|         assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5); |  | ||||||
|         assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5); |  | ||||||
|         assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5); |  | ||||||
|         assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5); |  | ||||||
|         assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void test_KNN_8() { |  | ||||||
|         float[] data = new float[]{4, 6}; |  | ||||||
|         List<Pair<Float, INDArray>> result = kdTree.knn(Nd4j.createFromArray(data), 20.0f); |  | ||||||
|         assertEquals(4.0, result.get(0).getSecond().getDouble(0), 1e-5); |  | ||||||
|         assertEquals(7.0, result.get(0).getSecond().getDouble(1), 1e-5); |  | ||||||
|         assertEquals(5.0, result.get(1).getSecond().getDouble(0), 1e-5); |  | ||||||
|         assertEquals(4.0, result.get(1).getSecond().getDouble(1), 1e-5); |  | ||||||
|         assertEquals(2.0, result.get(2).getSecond().getDouble(0), 1e-5); |  | ||||||
|         assertEquals(3.0, result.get(2).getSecond().getDouble(1), 1e-5); |  | ||||||
|         assertEquals(7.0, result.get(3).getSecond().getDouble(0), 1e-5); |  | ||||||
|         assertEquals(2.0, result.get(3).getSecond().getDouble(1), 1e-5); |  | ||||||
|         assertEquals(9.0, result.get(4).getSecond().getDouble(0), 1e-5); |  | ||||||
|         assertEquals(6.0, result.get(4).getSecond().getDouble(1), 1e-5); |  | ||||||
|         assertEquals(8.0, result.get(5).getSecond().getDouble(0), 1e-5); |  | ||||||
|         assertEquals(1.0, result.get(5).getSecond().getDouble(1), 1e-5); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Test |  | ||||||
|     public void testNoDuplicates() { |  | ||||||
|         int N = 100; |  | ||||||
|         KDTree bigTree = new KDTree(2); |  | ||||||
| 
 |  | ||||||
|         List<INDArray> points = new ArrayList<>(); |  | ||||||
|         for (int i = 0; i < N; ++i) { |  | ||||||
|             double[] data = new double[]{i, i}; |  | ||||||
|             points.add(Nd4j.createFromArray(data)); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         for (int i = 0; i < N; ++i) { |  | ||||||
|             bigTree.insert(points.get(i)); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         assertEquals(N, bigTree.size()); |  | ||||||
| 
 |  | ||||||
|         INDArray node = Nd4j.empty(DataType.DOUBLE); |  | ||||||
|         for (int i = 0; i < N; ++i) { |  | ||||||
|             node = bigTree.delete(node.isEmpty() ? points.get(i) : node); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         assertEquals(0, bigTree.size()); |  | ||||||
|     } |  | ||||||
| 
 |  | ||||||
|     @Ignore |  | ||||||
|     @Test |  | ||||||
|     public void performanceTest() { |  | ||||||
|         int n = 2; |  | ||||||
|         int num = 100000; |  | ||||||
|         // make a KD-tree of dimension n |  | ||||||
|         long start = System.currentTimeMillis(); |  | ||||||
|         KDTree kdTree = new KDTree(n); |  | ||||||
|         INDArray inputArray = Nd4j.randn(DataType.DOUBLE, num, n); |  | ||||||
|         for (int i = 0; i < num; ++i) { |  | ||||||
|             kdTree.insert(inputArray.getRow(i)); |  | ||||||
|         } |  | ||||||
| 
 |  | ||||||
|         long end = System.currentTimeMillis(); |  | ||||||
|         Duration duration = new Duration(start, end); |  | ||||||
|         System.out.println("Elapsed time for tree construction " + duration.getStandardSeconds() + " " + duration.getMillis()); |  | ||||||
| 
 |  | ||||||
|         List<Float> pt = new ArrayList<>(n); |  | ||||||
|         for (int k = 0; k < n; k++) { |  | ||||||
|             pt.add((float)(num / 2)); |  | ||||||
|         } |  | ||||||
|         start = System.currentTimeMillis(); |  | ||||||
|         List<Pair<Float, INDArray>> list = kdTree.knn(Nd4j.create(Nd4j.createBuffer(Doubles.toArray(pt))), 20.0f); |  | ||||||
|         end = System.currentTimeMillis(); |  | ||||||
|         duration = new Duration(start, end); |  | ||||||
|         System.out.println("Elapsed time for tree search " + duration.getStandardSeconds() + " " + duration.getMillis()); |  | ||||||
|         for (val pair : list) { |  | ||||||
|             System.out.println(pair.getFirst() + " " + pair.getSecond()) ; |  | ||||||
|         } |  | ||||||
|     } |  | ||||||
| } |  | ||||||
Some files were not shown because too many files have changed in this diff.