diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml
deleted file mode 100644
index c69e1abcb..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/pom.xml
+++ /dev/null
@@ -1,64 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<!--
-  ~ This program and the accompanying materials are made available under the
-  ~ terms of the Apache License, Version 2.0 which is available at
-  ~ https://www.apache.org/licenses/LICENSE-2.0.
-  ~
-  ~ See the NOTICE file distributed with this work for additional
-  ~ information regarding copyright ownership.
-  ~ Unless required by applicable law or agreed to in writing, software
-  ~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-  ~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-  ~ License for the specific language governing permissions and limitations
-  ~ under the License.
-  ~
-  ~ SPDX-License-Identifier: Apache-2.0
-  -->
-<project xmlns="http://maven.apache.org/POM/4.0.0"
-         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
-         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
-
-    <modelVersion>4.0.0</modelVersion>
-
-    <parent>
-        <groupId>org.datavec</groupId>
-        <artifactId>datavec-spark-inference-parent</artifactId>
-        <version>1.0.0-SNAPSHOT</version>
-    </parent>
-
-    <artifactId>datavec-spark-inference-client</artifactId>
-
-    <name>datavec-spark-inference-client</name>
-
-    <dependencies>
-        <dependency>
-            <groupId>org.datavec</groupId>
-            <artifactId>datavec-spark-inference-server_2.11</artifactId>
-            <version>1.0.0-SNAPSHOT</version>
-            <scope>test</scope>
-        </dependency>
-        <dependency>
-            <groupId>org.datavec</groupId>
-            <artifactId>datavec-spark-inference-model</artifactId>
-            <version>${project.parent.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.mashape.unirest</groupId>
-            <artifactId>unirest-java</artifactId>
-        </dependency>
-    </dependencies>
-
-    <profiles>
-        <profile>
-            <id>test-nd4j-native</id>
-        </profile>
-        <profile>
-            <id>test-nd4j-cuda-11.0</id>
-        </profile>
-    </profiles>
-</project>
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/inference/client/DataVecTransformClient.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/inference/client/DataVecTransformClient.java
deleted file mode 100644
index 8a346b096..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/main/java/org/datavec/spark/inference/client/DataVecTransformClient.java
+++ /dev/null
@@ -1,292 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.inference.client;
-
-
-import com.mashape.unirest.http.ObjectMapper;
-import com.mashape.unirest.http.Unirest;
-import com.mashape.unirest.http.exceptions.UnirestException;
-import lombok.AllArgsConstructor;
-import lombok.extern.slf4j.Slf4j;
-import org.datavec.api.transform.TransformProcess;
-import org.datavec.image.transform.ImageTransformProcess;
-import org.datavec.spark.inference.model.model.*;
-import org.datavec.spark.inference.model.service.DataVecTransformService;
-import org.nd4j.shade.jackson.core.JsonProcessingException;
-
-import java.io.IOException;
-
-@AllArgsConstructor
-@Slf4j
-public class DataVecTransformClient implements DataVecTransformService {
- private String url;
-
- static {
- // Configure Unirest's shared object mapper only once, in the static initializer
- Unirest.setObjectMapper(new ObjectMapper() {
- private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
- new org.nd4j.shade.jackson.databind.ObjectMapper();
-
- public <T> T readValue(String value, Class<T> valueType) {
- try {
- return jacksonObjectMapper.readValue(value, valueType);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- public String writeValue(Object value) {
- try {
- return jacksonObjectMapper.writeValueAsString(value);
- } catch (JsonProcessingException e) {
- throw new RuntimeException(e);
- }
- }
- });
- }
-
- /**
- * Set the CSV transform process on the remote server.
- *
- * @param transformProcess the transform process to use
- */
- @Override
- public void setCSVTransformProcess(TransformProcess transformProcess) {
- try {
- String s = transformProcess.toJson();
- Unirest.post(url + "/transformprocess").header("accept", "application/json")
- .header("Content-Type", "application/json").body(s).asJson();
-
- } catch (UnirestException e) {
- log.error("Error in setCSVTransformProcess()", e);
- }
- }
-
- @Override
- public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) {
- throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
- }
-
- /**
- * Fetch the CSV transform process currently set on the server.
- *
- * @return the current {@link TransformProcess}, or null if the request fails
- */
- @Override
- public TransformProcess getCSVTransformProcess() {
- try {
- String s = Unirest.get(url + "/transformprocess").header("accept", "application/json")
- .header("Content-Type", "application/json").asString().getBody();
- return TransformProcess.fromJson(s);
- } catch (UnirestException e) {
- log.error("Error in getCSVTransformProcess()",e);
- }
-
- return null;
- }
-
- @Override
- public ImageTransformProcess getImageTransformProcess() {
- throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
- }
-
- /**
- * Transform a single CSV record on the server.
- *
- * @param transform the record to transform
- * @return the transformed record, or null if the request fails
- */
- @Override
- public SingleCSVRecord transformIncremental(SingleCSVRecord transform) {
- try {
- SingleCSVRecord singleCsvRecord = Unirest.post(url + "/transformincremental")
- .header("accept", "application/json")
- .header("Content-Type", "application/json")
- .body(transform).asObject(SingleCSVRecord.class).getBody();
- return singleCsvRecord;
- } catch (UnirestException e) {
- log.error("Error in transformIncremental(SingleCSVRecord)",e);
- }
- return null;
- }
-
-
- /**
- * Transform a batch of sequence CSV records on the server.
- *
- * @param batchCSVRecord the sequence batch to transform
- * @return the transformed sequence batch, or null if the request fails
- */
- @Override
- public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
- try {
- SequenceBatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform").header("accept", "application/json")
- .header("Content-Type", "application/json")
- .header(SEQUENCE_OR_NOT_HEADER,"TRUE")
- .body(batchCSVRecord)
- .asObject(SequenceBatchCSVRecord.class)
- .getBody();
- return batchCSVRecord1;
- } catch (UnirestException e) {
- log.error("",e);
- }
-
- return null;
- }
- /**
- * Transform a batch of CSV records on the server.
- *
- * @param batchCSVRecord the batch to transform
- * @return the transformed batch, or null if the request fails
- */
- @Override
- public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) {
- try {
- BatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform").header("accept", "application/json")
- .header("Content-Type", "application/json")
- .header(SEQUENCE_OR_NOT_HEADER,"FALSE")
- .body(batchCSVRecord)
- .asObject(BatchCSVRecord.class)
- .getBody();
- return batchCSVRecord1;
- } catch (UnirestException e) {
- log.error("Error in transform(BatchCSVRecord)", e);
- }
-
- return null;
- }
-
- /**
- * Transform a batch of CSV records into a base64-encoded NDArray.
- *
- * @param batchCSVRecord the batch to convert
- * @return the base64-encoded array, or null if the request fails
- */
- @Override
- public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) {
- try {
- Base64NDArrayBody batchArray1 = Unirest.post(url + "/transformarray").header("accept", "application/json")
- .header("Content-Type", "application/json").body(batchCSVRecord)
- .asObject(Base64NDArrayBody.class).getBody();
- return batchArray1;
- } catch (UnirestException e) {
- log.error("Error in transformArray(BatchCSVRecord)",e);
- }
-
- return null;
- }
-
- /**
- * Transform a single CSV record into a base64-encoded NDArray.
- *
- * @param singleCsvRecord the record to convert
- * @return the base64-encoded array, or null if the request fails
- */
- @Override
- public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) {
- try {
- Base64NDArrayBody array = Unirest.post(url + "/transformincrementalarray")
- .header("accept", "application/json").header("Content-Type", "application/json")
- .body(singleCsvRecord).asObject(Base64NDArrayBody.class).getBody();
- return array;
- } catch (UnirestException e) {
- log.error("Error in transformArrayIncremental(SingleCSVRecord)",e);
- }
-
- return null;
- }
-
- @Override
- public Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException {
- throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
- }
-
- @Override
- public Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException {
- throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
- }
-
- /**
- * Incrementally transform a batch into a sequence NDArray.
- *
- * @param singleCsvRecord the batch to convert
- * @return the base64-encoded array, or null if the request fails
- */
- @Override
- public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) {
- try {
- Base64NDArrayBody array = Unirest.post(url + "/transformincrementalarray")
- .header("accept", "application/json")
- .header("Content-Type", "application/json")
- .header(SEQUENCE_OR_NOT_HEADER,"true")
- .body(singleCsvRecord).asObject(Base64NDArrayBody.class).getBody();
- return array;
- } catch (UnirestException e) {
- log.error("Error in transformSequenceArrayIncremental",e);
- }
-
- return null;
- }
-
- /**
- * Transform a sequence batch into a base64-encoded NDArray.
- *
- * @param batchCSVRecord the sequence batch to convert
- * @return the base64-encoded array, or null if the request fails
- */
- @Override
- public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) {
- try {
- Base64NDArrayBody batchArray1 = Unirest.post(url + "/transformarray").header("accept", "application/json")
- .header("Content-Type", "application/json")
- .header(SEQUENCE_OR_NOT_HEADER,"true")
- .body(batchCSVRecord)
- .asObject(Base64NDArrayBody.class).getBody();
- return batchArray1;
- } catch (UnirestException e) {
- log.error("Error in transformSequenceArray",e);
- }
-
- return null;
- }
-
- /**
- * Transform a batch of sequence CSV records.
- *
- * @param batchCSVRecord the sequence batch to transform
- * @return the transformed sequence batch, or null if the request fails
- */
- @Override
- public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) {
- try {
- SequenceBatchCSVRecord batchCSVRecord1 = Unirest.post(url + "/transform")
- .header("accept", "application/json")
- .header("Content-Type", "application/json")
- .header(SEQUENCE_OR_NOT_HEADER,"true")
- .body(batchCSVRecord)
- .asObject(SequenceBatchCSVRecord.class).getBody();
- return batchCSVRecord1;
- } catch (UnirestException e) {
- log.error("Error in transformSequence");
- }
-
- return null;
- }
-
- /**
- * Incrementally transform a batch into a sequence batch.
- *
- * @param transform the batch to transform
- * @return the transformed sequence batch, or null if the request fails
- */
- @Override
- public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) {
- try {
- SequenceBatchCSVRecord singleCsvRecord = Unirest.post(url + "/transformincremental")
- .header("accept", "application/json")
- .header("Content-Type", "application/json")
- .header(SEQUENCE_OR_NOT_HEADER,"true")
- .body(transform).asObject(SequenceBatchCSVRecord.class).getBody();
- return singleCsvRecord;
- } catch (UnirestException e) {
- log.error("Error in transformSequenceIncremental");
- }
- return null;
- }
-}
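
For reference, a minimal sketch of driving this client against a running CSVSparkTransformServer; the address, schema, and values below are illustrative, not taken from this diff:

    // Hypothetical usage sketch; assumes a CSVSparkTransformServer is listening on port 9600.
    Schema schema = new Schema.Builder()
            .addColumnDouble("first")
            .addColumnDouble("second")
            .build();
    TransformProcess transformProcess = new TransformProcess.Builder(schema)
            .convertToDouble("first")
            .convertToDouble("second")
            .build();
    DataVecTransformClient client = new DataVecTransformClient("http://localhost:9600");
    client.setCSVTransformProcess(transformProcess);
    // Each call mirrors a REST endpoint; on failure the client logs the error and returns null.
    SingleCSVRecord transformed = client.transformIncremental(new SingleCSVRecord("1.0", "2.0"));
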
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/AssertTestsExtendBaseClass.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/AssertTestsExtendBaseClass.java
deleted file mode 100644
index de2970b27..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/AssertTestsExtendBaseClass.java
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-package org.datavec.transform.client;
-
-import lombok.extern.slf4j.Slf4j;
-import org.nd4j.common.tests.AbstractAssertTestsClass;
-import org.nd4j.common.tests.BaseND4JTest;
-import java.util.*;
-
-@Slf4j
-public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
-
- @Override
- protected Set<Class<?>> getExclusions() {
- //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
- return new HashSet<>();
- }
-
- @Override
- protected String getPackageName() {
- return "org.datavec.transform.client";
- }
-
- @Override
- protected Class<?> getBaseClass() {
- return BaseND4JTest.class;
- }
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/DataVecTransformClientTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/DataVecTransformClientTest.java
deleted file mode 100644
index 6619ec443..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/java/org/datavec/transform/client/DataVecTransformClientTest.java
+++ /dev/null
@@ -1,139 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.transform.client;
-
-import org.apache.commons.io.FileUtils;
-import org.datavec.api.transform.TransformProcess;
-import org.datavec.api.transform.schema.Schema;
-import org.datavec.spark.inference.server.CSVSparkTransformServer;
-import org.datavec.spark.inference.client.DataVecTransformClient;
-import org.datavec.spark.inference.model.model.Base64NDArrayBody;
-import org.datavec.spark.inference.model.model.BatchCSVRecord;
-import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord;
-import org.datavec.spark.inference.model.model.SingleCSVRecord;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.Test;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.serde.base64.Nd4jBase64;
-
-import java.io.File;
-import java.io.IOException;
-import java.net.ServerSocket;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.UUID;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assume.assumeNotNull;
-
-public class DataVecTransformClientTest {
- private static CSVSparkTransformServer server;
- private static int port = getAvailablePort();
- private static DataVecTransformClient client;
- private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
- private static TransformProcess transformProcess =
- new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build();
- private static File fileSave = new File(UUID.randomUUID().toString() + ".json");
-
- @BeforeClass
- public static void beforeClass() throws Exception {
- FileUtils.write(fileSave, transformProcess.toJson());
- fileSave.deleteOnExit();
- server = new CSVSparkTransformServer();
- server.runMain(new String[] {"-dp", String.valueOf(port)});
-
- client = new DataVecTransformClient("http://localhost:" + port);
- client.setCSVTransformProcess(transformProcess);
- }
-
- @AfterClass
- public static void afterClass() throws Exception {
- server.stop();
- }
-
-
- @Test
- public void testSequenceClient() {
- SequenceBatchCSVRecord sequenceBatchCSVRecord = new SequenceBatchCSVRecord();
- SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"});
-
- BatchCSVRecord batchCSVRecord = new BatchCSVRecord(Arrays.asList(singleCsvRecord, singleCsvRecord));
- List<BatchCSVRecord> batchCSVRecordList = new ArrayList<>();
- for(int i = 0; i < 5; i++) {
- batchCSVRecordList.add(batchCSVRecord);
- }
-
- sequenceBatchCSVRecord.add(batchCSVRecordList);
-
- SequenceBatchCSVRecord sequenceBatchCSVRecord1 = client.transformSequence(sequenceBatchCSVRecord);
- assumeNotNull(sequenceBatchCSVRecord1);
-
- Base64NDArrayBody array = client.transformSequenceArray(sequenceBatchCSVRecord);
- assumeNotNull(array);
-
- Base64NDArrayBody incrementalBody = client.transformSequenceArrayIncremental(batchCSVRecord);
- assumeNotNull(incrementalBody);
-
- Base64NDArrayBody incrementalSequenceBody = client.transformSequenceArrayIncremental(batchCSVRecord);
- assumeNotNull(incrementalSequenceBody);
- }
-
- @Test
- public void testRecord() throws Exception {
- SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"});
- SingleCSVRecord transformed = client.transformIncremental(singleCsvRecord);
- assertEquals(singleCsvRecord.getValues().size(), transformed.getValues().size());
- Base64NDArrayBody body = client.transformArrayIncremental(singleCsvRecord);
- INDArray arr = Nd4jBase64.fromBase64(body.getNdarray());
- assumeNotNull(arr);
- }
-
- @Test
- public void testBatchRecord() throws Exception {
- SingleCSVRecord singleCsvRecord = new SingleCSVRecord(new String[] {"0", "0"});
-
- BatchCSVRecord batchCSVRecord = new BatchCSVRecord(Arrays.asList(singleCsvRecord, singleCsvRecord));
- BatchCSVRecord batchCSVRecord1 = client.transform(batchCSVRecord);
- assertEquals(batchCSVRecord.getRecords().size(), batchCSVRecord1.getRecords().size());
-
- Base64NDArrayBody body = client.transformArray(batchCSVRecord);
- INDArray arr = Nd4jBase64.fromBase64(body.getNdarray());
- assumeNotNull(arr);
- }
-
-
-
- public static int getAvailablePort() {
- try {
- ServerSocket socket = new ServerSocket(0);
- try {
- return socket.getLocalPort();
- } finally {
- socket.close();
- }
- } catch (IOException e) {
- throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e);
- }
- }
-
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/resources/application.conf b/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/resources/application.conf
deleted file mode 100644
index dbac92d83..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-client/src/test/resources/application.conf
+++ /dev/null
@@ -1,6 +0,0 @@
-play.modules.enabled += com.lightbend.lagom.discovery.zookeeper.ZooKeeperServiceLocatorModule
-play.modules.enabled += io.skymind.skil.service.PredictionModule
-play.crypto.secret = as8dufasdfuasdfjkasdkfalksjfk
-play.server.pidfile.path=/tmp/RUNNING_PID
-
-play.server.http.port = 9600
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml
deleted file mode 100644
index fe9ca985a..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/pom.xml
+++ /dev/null
@@ -1,63 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<!--
-  ~ This program and the accompanying materials are made available under the
-  ~ terms of the Apache License, Version 2.0 which is available at
-  ~ https://www.apache.org/licenses/LICENSE-2.0.
-  ~
-  ~ See the NOTICE file distributed with this work for additional
-  ~ information regarding copyright ownership.
-  ~ Unless required by applicable law or agreed to in writing, software
-  ~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-  ~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-  ~ License for the specific language governing permissions and limitations
-  ~ under the License.
-  ~
-  ~ SPDX-License-Identifier: Apache-2.0
-  -->
-<project xmlns="http://maven.apache.org/POM/4.0.0"
-         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
-         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
-
-    <modelVersion>4.0.0</modelVersion>
-
-    <parent>
-        <groupId>org.datavec</groupId>
-        <artifactId>datavec-spark-inference-parent</artifactId>
-        <version>1.0.0-SNAPSHOT</version>
-    </parent>
-
-    <artifactId>datavec-spark-inference-model</artifactId>
-
-    <name>datavec-spark-inference-model</name>
-
-    <dependencies>
-        <dependency>
-            <groupId>org.datavec</groupId>
-            <artifactId>datavec-api</artifactId>
-            <version>${datavec.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>org.datavec</groupId>
-            <artifactId>datavec-data-image</artifactId>
-        </dependency>
-        <dependency>
-            <groupId>org.datavec</groupId>
-            <artifactId>datavec-local</artifactId>
-            <version>${project.version}</version>
-        </dependency>
-    </dependencies>
-
-    <profiles>
-        <profile>
-            <id>test-nd4j-native</id>
-        </profile>
-        <profile>
-            <id>test-nd4j-cuda-11.0</id>
-        </profile>
-    </profiles>
-</project>
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/CSVSparkTransform.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/CSVSparkTransform.java
deleted file mode 100644
index e081708e0..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/CSVSparkTransform.java
+++ /dev/null
@@ -1,286 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.inference.model;
-
-import lombok.AllArgsConstructor;
-import lombok.Getter;
-import lombok.extern.slf4j.Slf4j;
-import lombok.val;
-import org.apache.arrow.memory.BufferAllocator;
-import org.apache.arrow.memory.RootAllocator;
-import org.apache.arrow.vector.FieldVector;
-import org.datavec.api.transform.TransformProcess;
-import org.datavec.api.util.ndarray.RecordConverter;
-import org.datavec.api.writable.Writable;
-import org.datavec.arrow.ArrowConverter;
-import org.datavec.arrow.recordreader.ArrowWritableRecordBatch;
-import org.datavec.arrow.recordreader.ArrowWritableRecordTimeSeriesBatch;
-import org.datavec.local.transforms.LocalTransformExecutor;
-import org.datavec.spark.inference.model.model.Base64NDArrayBody;
-import org.datavec.spark.inference.model.model.BatchCSVRecord;
-import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord;
-import org.datavec.spark.inference.model.model.SingleCSVRecord;
-import org.nd4j.linalg.api.buffer.DataType;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.serde.base64.Nd4jBase64;
-
-import java.io.IOException;
-import java.util.Arrays;
-import java.util.List;
-
-import static org.datavec.arrow.ArrowConverter.*;
-import static org.datavec.local.transforms.LocalTransformExecutor.execute;
-import static org.datavec.local.transforms.LocalTransformExecutor.executeToSequence;
-
-@AllArgsConstructor
-@Slf4j
-public class CSVSparkTransform {
- @Getter
- private TransformProcess transformProcess;
- private static BufferAllocator bufferAllocator = new RootAllocator(Long.MAX_VALUE);
-
- /**
- * Convert a raw record batch via the {@link TransformProcess}
- * to a base64-encoded ndarray.
- *
- * @param batch the records to convert
- * @return the base64-encoded ndarray
- * @throws IOException if serializing the array fails
- */
- public Base64NDArrayBody toArray(BatchCSVRecord batch) throws IOException {
- List<List<Writable>> converted = execute(toArrowWritables(toArrowColumnsString(
- bufferAllocator,transformProcess.getInitialSchema(),
- batch.getRecordsAsString()),
- transformProcess.getInitialSchema()),transformProcess);
-
- ArrowWritableRecordBatch arrowRecordBatch = (ArrowWritableRecordBatch) converted;
- INDArray convert = ArrowConverter.toArray(arrowRecordBatch);
- return new Base64NDArrayBody(Nd4jBase64.base64String(convert));
- }
-
- /**
- * Convert a raw record via the {@link TransformProcess}
- * to a base64-encoded ndarray.
- *
- * @param record the record to convert
- * @return the base64-encoded ndarray
- * @throws IOException if serializing the array fails
- */
- public Base64NDArrayBody toArray(SingleCSVRecord record) throws IOException {
- List<Writable> record2 = toArrowWritablesSingle(
- toArrowColumnsStringSingle(bufferAllocator,
- transformProcess.getInitialSchema(),record.getValues()),
- transformProcess.getInitialSchema());
- List<Writable> finalRecord = execute(Arrays.asList(record2),transformProcess).get(0);
- INDArray convert = RecordConverter.toArray(DataType.DOUBLE, finalRecord);
- return new Base64NDArrayBody(Nd4jBase64.base64String(convert));
- }
-
- /**
- * Runs the transform process
- * @param batch the record to transform
- * @return the transformed record
- */
- public BatchCSVRecord transform(BatchCSVRecord batch) {
- BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
- List<List<Writable>> converted = execute(toArrowWritables(toArrowColumnsString(
- bufferAllocator,transformProcess.getInitialSchema(),
- batch.getRecordsAsString()),
- transformProcess.getInitialSchema()),transformProcess);
- int numCols = converted.get(0).size();
- for (int row = 0; row < converted.size(); row++) {
- String[] values = new String[numCols];
- for (int i = 0; i < values.length; i++)
- values[i] = converted.get(row).get(i).toString();
- batchCSVRecord.add(new SingleCSVRecord(values));
- }
-
- return batchCSVRecord;
-
- }
-
- /**
- * Runs the transform process
- * @param record the record to transform
- * @return the transformed record
- */
- public SingleCSVRecord transform(SingleCSVRecord record) {
- List<Writable> record2 = toArrowWritablesSingle(
- toArrowColumnsStringSingle(bufferAllocator,
- transformProcess.getInitialSchema(),record.getValues()),
- transformProcess.getInitialSchema());
- List<Writable> finalRecord = execute(Arrays.asList(record2),transformProcess).get(0);
- String[] values = new String[finalRecord.size()];
- for (int i = 0; i < values.length; i++)
- values[i] = finalRecord.get(i).toString();
- return new SingleCSVRecord(values);
-
- }
-
- /**
- * Incrementally transform a batch into a sequence batch.
- *
- * @param transform the batch to transform
- * @return the transformed sequence batch
- */
- public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) {
- //TODO: clarify whether the initial (non-sequence) schema is valid for sequence data here
- List<List<List<Writable>>> converted = executeToSequence(
- toArrowWritables(toArrowColumnsStringTimeSeries(
- bufferAllocator, transformProcess.getInitialSchema(),
- Arrays.asList(transform.getRecordsAsString())),
- transformProcess.getInitialSchema()), transformProcess);
-
- SequenceBatchCSVRecord batchCSVRecord = new SequenceBatchCSVRecord();
- for (int i = 0; i < converted.size(); i++) {
- BatchCSVRecord batchCSVRecord1 = BatchCSVRecord.fromWritables(converted.get(i));
- batchCSVRecord.add(Arrays.asList(batchCSVRecord1));
- }
-
- return batchCSVRecord;
- }
-
- /**
- * Transform a batch of sequence CSV records.
- *
- * @param batchCSVRecordSequence the sequence batch to transform
- * @return the transformed sequence batch
- */
- public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecordSequence) {
- List<List<List<String>>> recordsAsString = batchCSVRecordSequence.getRecordsAsString();
- boolean allSameLength = true;
- Integer length = null;
- for(List<List<String>> record : recordsAsString) {
- if(length == null) {
- length = record.size();
- }
- else if(record.size() != length) {
- allSameLength = false;
- }
- }
-
- if(allSameLength) {
- List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), recordsAsString);
- ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors,
- transformProcess.getInitialSchema(),
- recordsAsString.get(0).get(0).size());
- val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch,transformProcess);
- return SequenceBatchCSVRecord.fromWritables(transformed);
- }
-
- else {
- val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecordSequence.getRecordsAsString(),transformProcess.getInitialSchema()),transformProcess);
- return SequenceBatchCSVRecord.fromWritables(transformed);
-
- }
- }
-
- /**
- * Transform a sequence batch into a base64-encoded NDArray.
- * TODO: optimize
- *
- * @param batchCSVRecordSequence the sequence batch to convert
- * @return the base64-encoded array
- */
- public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecordSequence) {
- List<List<List<String>>> strings = batchCSVRecordSequence.getRecordsAsString();
- boolean allSameLength = true;
- Integer length = null;
- for(List<List<String>> record : strings) {
- if(length == null) {
- length = record.size();
- }
- else if(record.size() != length) {
- allSameLength = false;
- }
- }
-
- if(allSameLength) {
- List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), strings);
- ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors,transformProcess.getInitialSchema(),strings.get(0).get(0).size());
- val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch,transformProcess);
- INDArray arr = RecordConverter.toTensor(transformed).reshape(strings.size(),strings.get(0).get(0).size(),strings.get(0).size());
- try {
- return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
- } catch (IOException e) {
- throw new IllegalStateException(e);
- }
- }
-
- else {
- val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecordSequence.getRecordsAsString(),transformProcess.getInitialSchema()),transformProcess);
- INDArray arr = RecordConverter.toTensor(transformed).reshape(strings.size(),strings.get(0).get(0).size(),strings.get(0).size());
- try {
- return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
- } catch (IOException e) {
- throw new IllegalStateException(e);
- }
- }
-
- }
-
- /**
- * Incrementally transform a batch into a base64-encoded NDArray.
- *
- * @param singleCsvRecord the batch to convert
- * @return the base64-encoded array, or null on serialization failure
- */
- public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) {
- List<List<List<Writable>>> converted = executeToSequence(toArrowWritables(toArrowColumnsString(
- bufferAllocator,transformProcess.getInitialSchema(),
- singleCsvRecord.getRecordsAsString()),
- transformProcess.getInitialSchema()),transformProcess);
- ArrowWritableRecordTimeSeriesBatch arrowWritableRecordBatch = (ArrowWritableRecordTimeSeriesBatch) converted;
- INDArray arr = RecordConverter.toTensor(arrowWritableRecordBatch);
- try {
- return new Base64NDArrayBody(Nd4jBase64.base64String(arr));
- } catch (IOException e) {
- log.error("",e);
- }
-
- return null;
- }
-
- public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
- List<List<List<String>>> strings = batchCSVRecord.getRecordsAsString();
- boolean allSameLength = true;
- Integer length = null;
- for(List<List<String>> record : strings) {
- if(length == null) {
- length = record.size();
- }
- else if(record.size() != length) {
- allSameLength = false;
- }
- }
-
- if(allSameLength) {
- List<FieldVector> fieldVectors = toArrowColumnsStringTimeSeries(bufferAllocator, transformProcess.getInitialSchema(), strings);
- ArrowWritableRecordTimeSeriesBatch arrowWritableRecordTimeSeriesBatch = new ArrowWritableRecordTimeSeriesBatch(fieldVectors,transformProcess.getInitialSchema(),strings.get(0).get(0).size());
- val transformed = LocalTransformExecutor.executeSequenceToSequence(arrowWritableRecordTimeSeriesBatch,transformProcess);
- return SequenceBatchCSVRecord.fromWritables(transformed);
- }
-
- else {
- val transformed = LocalTransformExecutor.executeSequenceToSequence(LocalTransformExecutor.convertStringInputTimeSeries(batchCSVRecord.getRecordsAsString(),transformProcess.getInitialSchema()),transformProcess);
- return SequenceBatchCSVRecord.fromWritables(transformed);
-
- }
-
- }
-}
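
A minimal sketch of exercising CSVSparkTransform directly, without the REST layer; the schema and values are illustrative:

    // Hypothetical local usage; the transform runs via LocalTransformExecutor, no Spark cluster needed.
    Schema schema = new Schema.Builder()
            .addColumnDouble("first")
            .addColumnDouble("second")
            .build();
    TransformProcess transformProcess = new TransformProcess.Builder(schema)
            .convertToDouble("first")
            .convertToDouble("second")
            .build();
    CSVSparkTransform transform = new CSVSparkTransform(transformProcess);
    SingleCSVRecord out = transform.transform(new SingleCSVRecord("1.0", "2.0"));
    Base64NDArrayBody array = transform.toArray(new SingleCSVRecord("1.0", "2.0")); // throws IOException
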
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/ImageSparkTransform.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/ImageSparkTransform.java
deleted file mode 100644
index a004c439b..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/ImageSparkTransform.java
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.inference.model;
-
-import lombok.AllArgsConstructor;
-import lombok.Getter;
-import org.datavec.image.data.ImageWritable;
-import org.datavec.image.transform.ImageTransformProcess;
-import org.datavec.spark.inference.model.model.Base64NDArrayBody;
-import org.datavec.spark.inference.model.model.BatchImageRecord;
-import org.datavec.spark.inference.model.model.SingleImageRecord;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.serde.base64.Nd4jBase64;
-
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
-@AllArgsConstructor
-public class ImageSparkTransform {
- @Getter
- private ImageTransformProcess imageTransformProcess;
-
- public Base64NDArrayBody toArray(SingleImageRecord record) throws IOException {
- ImageWritable record2 = imageTransformProcess.transformFileUriToInput(record.getUri());
- INDArray finalRecord = imageTransformProcess.executeArray(record2);
-
- return new Base64NDArrayBody(Nd4jBase64.base64String(finalRecord));
- }
-
- public Base64NDArrayBody toArray(BatchImageRecord batch) throws IOException {
- List<INDArray> records = new ArrayList<>();
-
- for (SingleImageRecord imgRecord : batch.getRecords()) {
- ImageWritable record2 = imageTransformProcess.transformFileUriToInput(imgRecord.getUri());
- INDArray finalRecord = imageTransformProcess.executeArray(record2);
- records.add(finalRecord);
- }
-
- INDArray array = Nd4j.concat(0, records.toArray(new INDArray[records.size()]));
-
- return new Base64NDArrayBody(Nd4jBase64.base64String(array));
- }
-
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/Base64NDArrayBody.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/Base64NDArrayBody.java
deleted file mode 100644
index 0d6c680ad..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/Base64NDArrayBody.java
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.inference.model.model;
-
-import lombok.AllArgsConstructor;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-
-@Data
-@AllArgsConstructor
-@NoArgsConstructor
-public class Base64NDArrayBody {
- private String ndarray;
-}
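
The ndarray field holds an Nd4jBase64-encoded string; a short round-trip sketch with illustrative values:

    // Encode an INDArray into the body, then decode it back; both calls may throw IOException.
    INDArray original = Nd4j.create(new double[][] {{1, 2}, {3, 4}});
    Base64NDArrayBody body = new Base64NDArrayBody(Nd4jBase64.base64String(original));
    INDArray decoded = Nd4jBase64.fromBase64(body.getNdarray());
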
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/BatchCSVRecord.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/BatchCSVRecord.java
deleted file mode 100644
index 82ecedc51..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/BatchCSVRecord.java
+++ /dev/null
@@ -1,104 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.inference.model.model;
-
-import lombok.AllArgsConstructor;
-import lombok.Builder;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-import org.datavec.api.writable.Writable;
-import org.nd4j.linalg.dataset.DataSet;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
-
-@Data
-@AllArgsConstructor
-@Builder
-@NoArgsConstructor
-public class BatchCSVRecord implements Serializable {
- private List<SingleCSVRecord> records;
-
-
- /**
- * Get the records as a list of string lists
- * (the underlying values of each {@link SingleCSVRecord}).
- *
- * @return the records as their raw string values
- */
- public List<List<String>> getRecordsAsString() {
- if(records == null)
- records = new ArrayList<>();
- List<List<String>> ret = new ArrayList<>();
- for(SingleCSVRecord csvRecord : records) {
- ret.add(csvRecord.getValues());
- }
- return ret;
- }
-
-
- /**
- * Create a batch csv record from a list of writables.
- *
- * @param batch the writables to convert
- * @return the equivalent batch record
- */
- public static BatchCSVRecord fromWritables(List<List<Writable>> batch) {
- List<SingleCSVRecord> records = new ArrayList<>(batch.size());
- for(List<Writable> list : batch) {
- List<String> add = new ArrayList<>(list.size());
- for(Writable writable : list) {
- add.add(writable.toString());
- }
- records.add(new SingleCSVRecord(add));
- }
-
- return BatchCSVRecord.builder().records(records).build();
- }
-
-
- /**
- * Add a record to the batch.
- *
- * @param record the record to add
- */
- public void add(SingleCSVRecord record) {
- if (records == null)
- records = new ArrayList<>();
- records.add(record);
- }
-
-
- /**
- * Return a batch record based on a dataset
- * @param dataSet the dataset to get the batch record for
- * @return the batch record
- */
- public static BatchCSVRecord fromDataSet(DataSet dataSet) {
- BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
- for (int i = 0; i < dataSet.numExamples(); i++) {
- batchCSVRecord.add(SingleCSVRecord.fromRow(dataSet.get(i)));
- }
-
- return batchCSVRecord;
- }
-
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/BatchImageRecord.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/BatchImageRecord.java
deleted file mode 100644
index ff101c659..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/BatchImageRecord.java
+++ /dev/null
@@ -1,50 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.inference.model.model;
-
-import lombok.AllArgsConstructor;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-
-import java.net.URI;
-import java.util.ArrayList;
-import java.util.List;
-
-@Data
-@AllArgsConstructor
-@NoArgsConstructor
-public class BatchImageRecord {
- private List<SingleImageRecord> records;
-
- /**
- * Add a record to the batch.
- *
- * @param record the record to add
- */
- public void add(SingleImageRecord record) {
- if (records == null)
- records = new ArrayList<>();
- records.add(record);
- }
-
- public void add(URI uri) {
- this.add(new SingleImageRecord(uri));
- }
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SequenceBatchCSVRecord.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SequenceBatchCSVRecord.java
deleted file mode 100644
index eed4fac59..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SequenceBatchCSVRecord.java
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.inference.model.model;
-
-import lombok.AllArgsConstructor;
-import lombok.Builder;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-import org.datavec.api.writable.Writable;
-import org.nd4j.linalg.dataset.DataSet;
-import org.nd4j.linalg.dataset.MultiDataSet;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-
-@Data
-@AllArgsConstructor
-@Builder
-@NoArgsConstructor
-public class SequenceBatchCSVRecord implements Serializable {
- private List<List<BatchCSVRecord>> records;
-
- /**
- * Add a record to the sequence.
- *
- * @param record the record to add
- */
- public void add(List<BatchCSVRecord> record) {
- if (records == null)
- records = new ArrayList<>();
- records.add(record);
- }
-
- /**
- * Get the records as nested lists of strings directly
- * (this "unpacks" the record objects).
- *
- * @return the raw string values of every record
- */
- public List<List<List<String>>> getRecordsAsString() {
- if(records == null)
- return Collections.emptyList();
- List<List<List<String>>> ret = new ArrayList<>(records.size());
- for(List<BatchCSVRecord> record : records) {
- List<List<String>> add = new ArrayList<>();
- for(BatchCSVRecord batchCSVRecord : record) {
- for (SingleCSVRecord singleCSVRecord : batchCSVRecord.getRecords()) {
- add.add(singleCSVRecord.getValues());
- }
- }
-
- ret.add(add);
- }
-
- return ret;
- }
-
- /**
- * Convert a writables time series to a sequence batch.
- *
- * @param input the time series writables to convert
- * @return the equivalent sequence batch
- */
- public static SequenceBatchCSVRecord fromWritables(List<List<List<Writable>>> input) {
- SequenceBatchCSVRecord ret = new SequenceBatchCSVRecord();
- for(int i = 0; i < input.size(); i++) {
- ret.add(Arrays.asList(BatchCSVRecord.fromWritables(input.get(i))));
- }
-
- return ret;
- }
-
-
- /**
- * Return a batch record based on a dataset
- * @param dataSet the dataset to get the batch record for
- * @return the batch record
- */
- public static SequenceBatchCSVRecord fromDataSet(MultiDataSet dataSet) {
- SequenceBatchCSVRecord batchCSVRecord = new SequenceBatchCSVRecord();
- for (int i = 0; i < dataSet.numFeatureArrays(); i++) {
- batchCSVRecord.add(Arrays.asList(BatchCSVRecord.fromDataSet(new DataSet(dataSet.getFeatures(i),dataSet.getLabels(i)))));
- }
-
- return batchCSVRecord;
- }
-
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SingleCSVRecord.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SingleCSVRecord.java
deleted file mode 100644
index 575a91918..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SingleCSVRecord.java
+++ /dev/null
@@ -1,95 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.inference.model.model;
-
-import lombok.AllArgsConstructor;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-import org.nd4j.linalg.dataset.DataSet;
-
-import java.io.Serializable;
-import java.util.Arrays;
-import java.util.List;
-
-@Data
-@AllArgsConstructor
-@NoArgsConstructor
-public class SingleCSVRecord implements Serializable {
- private List<String> values;
-
- /**
- * Create from an array of values (uses a list internally).
- *
- * @param values the values of the record
- */
- public SingleCSVRecord(String...values) {
- this.values = Arrays.asList(values);
- }
-
- /**
- * Instantiate a csv record from a {@link DataSet} row.
- * For classification (one-hot labels), the predicted class index
- * is appended to the end of the record; for regression, all label
- * values are appended.
- *
- * @param row the input dataset row
- * @return the record for this {@link DataSet}
- */
- public static SingleCSVRecord fromRow(DataSet row) {
- if (!row.getFeatures().isVector() && !row.getFeatures().isScalar())
- throw new IllegalArgumentException("Passed in dataset must represent a scalar or vector");
- if (!row.getLabels().isVector() && !row.getLabels().isScalar())
- throw new IllegalArgumentException("Passed in dataset labels must be a scalar or vector");
- //classification
- SingleCSVRecord record;
- int idx = 0;
- if (row.getLabels().sumNumber().doubleValue() == 1.0) {
- String[] values = new String[row.getFeatures().columns() + 1];
- for (int i = 0; i < row.getFeatures().length(); i++) {
- values[idx++] = String.valueOf(row.getFeatures().getDouble(i));
- }
- int maxIdx = 0;
- for (int i = 0; i < row.getLabels().length(); i++) {
- if (row.getLabels().getDouble(maxIdx) < row.getLabels().getDouble(i)) {
- maxIdx = i;
- }
- }
-
- values[idx++] = String.valueOf(maxIdx);
- record = new SingleCSVRecord(values);
- }
- //regression (any number of values)
- else {
- String[] values = new String[row.getFeatures().columns() + row.getLabels().columns()];
- for (int i = 0; i < row.getFeatures().length(); i++) {
- values[idx++] = String.valueOf(row.getFeatures().getDouble(i));
- }
- for (int i = 0; i < row.getLabels().length(); i++) {
- values[idx++] = String.valueOf(row.getLabels().getDouble(i));
- }
-
-
- record = new SingleCSVRecord(values);
-
- }
- return record;
- }
-
-}
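
A short sketch of fromRow's two branches, with illustrative values:

    // Classification branch: one-hot labels sum to 1.0, so the argmax index is appended.
    DataSet oneHotRow = new DataSet(Nd4j.create(new double[][] {{0.5, 2.0}}),
            Nd4j.create(new double[][] {{0, 1, 0}}));
    SingleCSVRecord classRec = SingleCSVRecord.fromRow(oneHotRow);   // ["0.5", "2.0", "1"]

    // Regression branch: labels do not sum to 1.0, so every label value is appended.
    DataSet regressionRow = new DataSet(Nd4j.create(new double[][] {{0.5, 2.0}}),
            Nd4j.create(new double[][] {{3.5, 4.5}}));
    SingleCSVRecord regRec = SingleCSVRecord.fromRow(regressionRow); // ["0.5", "2.0", "3.5", "4.5"]
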
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SingleImageRecord.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SingleImageRecord.java
deleted file mode 100644
index 9fe3df042..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/model/SingleImageRecord.java
+++ /dev/null
@@ -1,34 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.inference.model.model;
-
-import lombok.AllArgsConstructor;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-
-import java.net.URI;
-
-@Data
-@AllArgsConstructor
-@NoArgsConstructor
-public class SingleImageRecord {
- private URI uri;
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/service/DataVecTransformService.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/service/DataVecTransformService.java
deleted file mode 100644
index c23dd562c..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/main/java/org/datavec/spark/inference/model/service/DataVecTransformService.java
+++ /dev/null
@@ -1,131 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.inference.model.service;
-
-import org.datavec.api.transform.TransformProcess;
-import org.datavec.image.transform.ImageTransformProcess;
-import org.datavec.spark.inference.model.model.*;
-
-import java.io.IOException;
-
-public interface DataVecTransformService {
-
- String SEQUENCE_OR_NOT_HEADER = "Sequence";
-
-
- /**
- * Set the CSV transform process to apply.
- *
- * @param transformProcess the transform process
- */
- void setCSVTransformProcess(TransformProcess transformProcess);
-
- /**
- * Set the image transform process to apply.
- *
- * @param imageTransformProcess the image transform process
- */
- void setImageTransformProcess(ImageTransformProcess imageTransformProcess);
-
- /**
- * @return the current CSV transform process
- */
- TransformProcess getCSVTransformProcess();
-
- /**
- * @return the current image transform process
- */
- ImageTransformProcess getImageTransformProcess();
-
- /**
- * Transform a single CSV record.
- *
- * @param singleCsvRecord the record to transform
- * @return the transformed record
- */
- SingleCSVRecord transformIncremental(SingleCSVRecord singleCsvRecord);
-
- SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord);
-
- /**
- * Transform a batch of CSV records.
- *
- * @param batchCSVRecord the batch to transform
- * @return the transformed batch
- */
- BatchCSVRecord transform(BatchCSVRecord batchCSVRecord);
-
- /**
- * Transform a batch of CSV records into an NDArray.
- *
- * @param batchCSVRecord the batch to convert
- * @return the base64-encoded array
- */
- Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord);
-
- /**
- * Transform a single CSV record into an NDArray.
- *
- * @param singleCsvRecord the record to convert
- * @return the base64-encoded array
- */
- Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord);
-
- /**
- * Transform a single image into an NDArray.
- *
- * @param singleImageRecord the image to convert
- * @return the base64-encoded array
- * @throws IOException if the image cannot be read
- */
- Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException;
-
- /**
- * Transform a batch of images into an NDArray.
- *
- * @param batchImageRecord the images to convert
- * @return the base64-encoded array
- * @throws IOException if an image cannot be read
- */
- Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException;
-
- /**
- * Incrementally transform a batch into a sequence NDArray.
- *
- * @param singleCsvRecord the batch to convert
- * @return the base64-encoded array
- */
- Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord);
-
- /**
- * Transform a sequence batch into an NDArray.
- *
- * @param batchCSVRecord the sequence batch to convert
- * @return the base64-encoded array
- */
- Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord);
-
- /**
- * Transform a batch of sequence CSV records.
- *
- * @param batchCSVRecord the sequence batch to transform
- * @return the transformed sequence batch
- */
- SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord);
-
- /**
- * Incrementally transform a batch into a sequence batch.
- *
- * @param transform the batch to transform
- * @return the transformed sequence batch
- */
- SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform);
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java
deleted file mode 100644
index ab76b206e..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-package org.datavec.spark.transform;
-
-import lombok.extern.slf4j.Slf4j;
-import org.nd4j.common.tests.AbstractAssertTestsClass;
-import org.nd4j.common.tests.BaseND4JTest;
-
-import java.util.*;
-
-@Slf4j
-public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
-
- @Override
- protected Set<Class<?>> getExclusions() {
- //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
- return new HashSet<>();
- }
-
- @Override
- protected String getPackageName() {
- return "org.datavec.spark.transform";
- }
-
- @Override
- protected Class<?> getBaseClass() {
- return BaseND4JTest.class;
- }
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/BatchCSVRecordTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/BatchCSVRecordTest.java
deleted file mode 100644
index a5ce6c474..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/BatchCSVRecordTest.java
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.transform;
-
-import org.datavec.spark.inference.model.model.BatchCSVRecord;
-import org.junit.Test;
-import org.nd4j.linalg.dataset.DataSet;
-import org.nd4j.linalg.factory.Nd4j;
-
-import static org.junit.Assert.assertEquals;
-
-public class BatchCSVRecordTest {
-
- @Test
- public void testBatchRecordCreationFromDataSet() {
- DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{1, 1}, {1, 1}}));
-
- BatchCSVRecord batchCSVRecord = BatchCSVRecord.fromDataSet(dataSet);
- assertEquals(2, batchCSVRecord.getRecords().size());
- }
-
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/CSVSparkTransformTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/CSVSparkTransformTest.java
deleted file mode 100644
index 7d1fe5f3b..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/CSVSparkTransformTest.java
+++ /dev/null
@@ -1,212 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.transform;
-
-import org.datavec.api.transform.TransformProcess;
-import org.datavec.api.transform.schema.Schema;
-import org.datavec.api.transform.transform.integer.BaseIntegerTransform;
-import org.datavec.api.transform.transform.nlp.TextToCharacterIndexTransform;
-import org.datavec.api.writable.DoubleWritable;
-import org.datavec.api.writable.Text;
-import org.datavec.api.writable.Writable;
-import org.datavec.spark.inference.model.CSVSparkTransform;
-import org.datavec.spark.inference.model.model.Base64NDArrayBody;
-import org.datavec.spark.inference.model.model.BatchCSVRecord;
-import org.datavec.spark.inference.model.model.SequenceBatchCSVRecord;
-import org.datavec.spark.inference.model.model.SingleCSVRecord;
-import org.junit.Test;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.serde.base64.Nd4jBase64;
-
-import java.util.*;
-
-import static org.junit.Assert.*;
-
-public class CSVSparkTransformTest {
- @Test
- public void testTransformer() throws Exception {
- List<Writable> input = new ArrayList<>();
- input.add(new DoubleWritable(1.0));
- input.add(new DoubleWritable(2.0));
-
- Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
- List<Writable> output = new ArrayList<>();
- output.add(new Text("1.0"));
- output.add(new Text("2.0"));
-
- TransformProcess transformProcess =
- new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build();
- CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess);
- String[] values = new String[] {"1.0", "2.0"};
- SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values));
- Base64NDArrayBody body = csvSparkTransform.toArray(new SingleCSVRecord(values));
- INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
- assertTrue(fromBase64.isVector());
-// System.out.println("Base 64ed array " + fromBase64);
- }
-
- @Test
- public void testTransformerBatch() throws Exception {
- List<Writable> input = new ArrayList<>();
- input.add(new DoubleWritable(1.0));
- input.add(new DoubleWritable(2.0));
-
- Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
- List<Writable> output = new ArrayList<>();
- output.add(new Text("1.0"));
- output.add(new Text("2.0"));
-
- TransformProcess transformProcess =
- new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build();
- CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess);
- String[] values = new String[] {"1.0", "2.0"};
- SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values));
- BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
- for (int i = 0; i < 3; i++)
- batchCSVRecord.add(record);
- //data type is string, unable to convert
- BatchCSVRecord batchCSVRecord1 = csvSparkTransform.transform(batchCSVRecord);
- /* Base64NDArrayBody body = csvSparkTransform.toArray(batchCSVRecord1);
- INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
- assertTrue(fromBase64.isMatrix());
- System.out.println("Base 64ed array " + fromBase64); */
- }
-
-
-
- @Test
- public void testSingleBatchSequence() throws Exception {
- List<Writable> input = new ArrayList<>();
- input.add(new DoubleWritable(1.0));
- input.add(new DoubleWritable(2.0));
-
- Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
- List<Writable> output = new ArrayList<>();
- output.add(new Text("1.0"));
- output.add(new Text("2.0"));
-
- TransformProcess transformProcess =
- new TransformProcess.Builder(schema).convertToString("1.0").convertToString("2.0").build();
- CSVSparkTransform csvSparkTransform = new CSVSparkTransform(transformProcess);
- String[] values = new String[] {"1.0", "2.0"};
- SingleCSVRecord record = csvSparkTransform.transform(new SingleCSVRecord(values));
- BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
- for (int i = 0; i < 3; i++)
- batchCSVRecord.add(record);
- BatchCSVRecord batchCSVRecord1 = csvSparkTransform.transform(batchCSVRecord);
- SequenceBatchCSVRecord sequenceBatchCSVRecord = new SequenceBatchCSVRecord();
- sequenceBatchCSVRecord.add(Arrays.asList(batchCSVRecord));
- Base64NDArrayBody sequenceArray = csvSparkTransform.transformSequenceArray(sequenceBatchCSVRecord);
- INDArray outputBody = Nd4jBase64.fromBase64(sequenceArray.getNdarray());
-
-
- //ensure accumulation
- sequenceBatchCSVRecord.add(Arrays.asList(batchCSVRecord));
- sequenceArray = csvSparkTransform.transformSequenceArray(sequenceBatchCSVRecord);
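- //sequence arrays use the [miniBatch, featureColumns, timeSteps] layout: 2 accumulated batches, 2 columns, 3 records per batch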
- assertArrayEquals(new long[]{2,2,3},Nd4jBase64.fromBase64(sequenceArray.getNdarray()).shape());
-
- SequenceBatchCSVRecord transformed = csvSparkTransform.transformSequence(sequenceBatchCSVRecord);
- assertNotNull(transformed.getRecords());
-// System.out.println(transformed);
-
-
- }
-
- @Test
- public void testSpecificSequence() throws Exception {
- final Schema schema = new Schema.Builder()
- .addColumnsString("action")
- .build();
-
- final TransformProcess transformProcess = new TransformProcess.Builder(schema)
- .removeAllColumnsExceptFor("action")
- .transform(new ConverToLowercase("action"))
- .convertToSequence()
- .transform(new TextToCharacterIndexTransform("action", "action_sequence",
- defaultCharIndex(), false))
- .integerToOneHot("action_sequence",0,29)
- .build();
-
- final String[] data1 = new String[] { "test1" };
- final String[] data2 = new String[] { "test2" };
- final BatchCSVRecord batchCsvRecord = new BatchCSVRecord(
- Arrays.asList(
- new SingleCSVRecord(data1),
- new SingleCSVRecord(data2)));
-
- final CSVSparkTransform transform = new CSVSparkTransform(transformProcess);
-// System.out.println(transform.transformSequenceIncremental(batchCsvRecord));
- transform.transformSequenceIncremental(batchCsvRecord);
- assertEquals(3,Nd4jBase64.fromBase64(transform.transformSequenceArrayIncremental(batchCsvRecord).getNdarray()).rank());
-
- }
-
- private static Map<Character, Integer> defaultCharIndex() {
- Map<Character, Integer> ret = new TreeMap<>();
-
- ret.put('a',0);
- ret.put('b',1);
- ret.put('c',2);
- ret.put('d',3);
- ret.put('e',4);
- ret.put('f',5);
- ret.put('g',6);
- ret.put('h',7);
- ret.put('i',8);
- ret.put('j',9);
- ret.put('k',10);
- ret.put('l',11);
- ret.put('m',12);
- ret.put('n',13);
- ret.put('o',14);
- ret.put('p',15);
- ret.put('q',16);
- ret.put('r',17);
- ret.put('s',18);
- ret.put('t',19);
- ret.put('u',20);
- ret.put('v',21);
- ret.put('w',22);
- ret.put('x',23);
- ret.put('y',24);
- ret.put('z',25);
- ret.put('/',26);
- ret.put(' ',27);
- ret.put('(',28);
- ret.put(')',29);
-
- return ret;
- }
-
- public static class ConverToLowercase extends BaseIntegerTransform {
- public ConverToLowercase(String column) {
- super(column);
- }
-
- public Text map(Writable writable) {
- return new Text(writable.toString().toLowerCase());
- }
-
- public Object map(Object input) {
- return new Text(input.toString().toLowerCase());
- }
- }
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/ImageSparkTransformTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/ImageSparkTransformTest.java
deleted file mode 100644
index 415730b18..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/ImageSparkTransformTest.java
+++ /dev/null
@@ -1,86 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.transform;
-
-import org.datavec.image.transform.ImageTransformProcess;
-import org.datavec.spark.inference.model.ImageSparkTransform;
-import org.datavec.spark.inference.model.model.Base64NDArrayBody;
-import org.datavec.spark.inference.model.model.BatchImageRecord;
-import org.datavec.spark.inference.model.model.SingleImageRecord;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.common.io.ClassPathResource;
-import org.nd4j.serde.base64.Nd4jBase64;
-
-import java.io.File;
-
-import static org.junit.Assert.assertEquals;
-
-public class ImageSparkTransformTest {
-
- @Rule
- public TemporaryFolder testDir = new TemporaryFolder();
-
- @Test
- public void testSingleImageSparkTransform() throws Exception {
- int seed = 12345;
-
- File f1 = new ClassPathResource("datavec-spark-inference/testimages/class1/A.jpg").getFile();
-
- SingleImageRecord imgRecord = new SingleImageRecord(f1.toURI());
-
- ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(seed)
- .scaleImageTransform(10).cropImageTransform(5).build();
-
- ImageSparkTransform imgSparkTransform = new ImageSparkTransform(imgTransformProcess);
- Base64NDArrayBody body = imgSparkTransform.toArray(imgRecord);
-
- INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
-// System.out.println("Base 64ed array " + fromBase64);
- assertEquals(1, fromBase64.size(0));
- }
-
- @Test
- public void testBatchImageSparkTransform() throws Exception {
- int seed = 12345;
-
- File f0 = new ClassPathResource("datavec-spark-inference/testimages/class1/A.jpg").getFile();
- File f1 = new ClassPathResource("datavec-spark-inference/testimages/class1/B.png").getFile();
- File f2 = new ClassPathResource("datavec-spark-inference/testimages/class1/C.jpg").getFile();
-
- BatchImageRecord batch = new BatchImageRecord();
- batch.add(f0.toURI());
- batch.add(f1.toURI());
- batch.add(f2.toURI());
-
- ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(seed)
- .scaleImageTransform(10).cropImageTransform(5).build();
-
- ImageSparkTransform imgSparkTransform = new ImageSparkTransform(imgTransformProcess);
- Base64NDArrayBody body = imgSparkTransform.toArray(batch);
-
- INDArray fromBase64 = Nd4jBase64.fromBase64(body.getNdarray());
-// System.out.println("Base 64ed array " + fromBase64);
- assertEquals(3, fromBase64.size(0));
- }
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/SingleCSVRecordTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/SingleCSVRecordTest.java
deleted file mode 100644
index 599f8eead..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/SingleCSVRecordTest.java
+++ /dev/null
@@ -1,60 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.transform;
-
-import org.datavec.spark.inference.model.model.SingleCSVRecord;
-import org.junit.Test;
-import org.nd4j.linalg.dataset.DataSet;
-import org.nd4j.linalg.factory.Nd4j;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.fail;
-
-public class SingleCSVRecordTest {
-
- @Test(expected = IllegalArgumentException.class)
- public void testVectorAssertion() {
- DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(1, 1));
- SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet);
- fail(singleCsvRecord.toString() + " should have thrown an exception");
- }
-
- @Test
- public void testVectorOneHotLabel() {
- DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{0, 1}, {1, 0}}));
-
- //one-hot label rows are treated as classification: 2 features + 1 class index = 3 values
- SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet.get(0));
- assertEquals(3, singleCsvRecord.getValues().size());
-
- }
-
- @Test
- public void testVectorRegression() {
- DataSet dataSet = new DataSet(Nd4j.create(2, 2), Nd4j.create(new double[][] {{1, 1}, {1, 1}}));
-
- //regression labels are written as-is: 2 features + 2 label values = 4 values
- SingleCSVRecord singleCsvRecord = SingleCSVRecord.fromRow(dataSet.get(0));
- assertEquals(4, singleCsvRecord.getValues().size());
-
- }
-
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/SingleImageRecordTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/SingleImageRecordTest.java
deleted file mode 100644
index 3c321e583..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-model/src/test/java/org/datavec/spark/transform/SingleImageRecordTest.java
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.transform;
-
-import org.datavec.spark.inference.model.model.SingleImageRecord;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
-import org.nd4j.common.io.ClassPathResource;
-
-import java.io.File;
-
-public class SingleImageRecordTest {
-
- @Rule
- public TemporaryFolder testDir = new TemporaryFolder();
-
- @Test
- public void testImageRecord() throws Exception {
- File f = testDir.newFolder();
- new ClassPathResource("datavec-spark-inference/testimages/").copyDirectory(f);
- File f0 = new File(f, "class0/0.jpg");
- File f1 = new File(f, "class1/A.jpg");
-
- SingleImageRecord imgRecord = new SingleImageRecord(f0.toURI());
-
- // TODO: add a Jackson serialization round-trip test
- }
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml
deleted file mode 100644
index 8a65942db..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/pom.xml
+++ /dev/null
@@ -1,154 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<!--
-  ~ This program and the accompanying materials are made available under the
-  ~ terms of the Apache License, Version 2.0 which is available at
-  ~ https://www.apache.org/licenses/LICENSE-2.0.
-  ~
-  ~ See the NOTICE file distributed with this work for additional
-  ~ information regarding copyright ownership.
-  ~ Unless required by applicable law or agreed to in writing, software
-  ~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
-  ~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
-  ~ License for the specific language governing permissions and limitations
-  ~ under the License.
-  ~
-  ~ SPDX-License-Identifier: Apache-2.0
-  -->
-<project xmlns="http://maven.apache.org/POM/4.0.0"
-         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
-         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
-    <modelVersion>4.0.0</modelVersion>
-
-    <parent>
-        <groupId>org.datavec</groupId>
-        <artifactId>datavec-spark-inference-parent</artifactId>
-        <version>1.0.0-SNAPSHOT</version>
-    </parent>
-
-    <artifactId>datavec-spark-inference-server_2.11</artifactId>
-
-    <name>datavec-spark-inference-server</name>
-
-    <properties>
-        <scala.version>2.11.12</scala.version>
-        <scala.binary.version>2.11</scala.binary.version>
-        <maven.compiler.source>1.8</maven.compiler.source>
-        <maven.compiler.target>1.8</maven.compiler.target>
-    </properties>
-
-    <dependencies>
-        <dependency>
-            <groupId>org.datavec</groupId>
-            <artifactId>datavec-spark-inference-model</artifactId>
-            <version>${datavec.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>org.datavec</groupId>
-            <artifactId>datavec-spark_2.11</artifactId>
-            <version>${project.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>org.datavec</groupId>
-            <artifactId>datavec-data-image</artifactId>
-        </dependency>
-        <dependency>
-            <groupId>joda-time</groupId>
-            <artifactId>joda-time</artifactId>
-        </dependency>
-        <dependency>
-            <groupId>org.apache.commons</groupId>
-            <artifactId>commons-lang3</artifactId>
-        </dependency>
-        <dependency>
-            <groupId>org.hibernate</groupId>
-            <artifactId>hibernate-validator</artifactId>
-            <version>${hibernate.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>org.scala-lang</groupId>
-            <artifactId>scala-library</artifactId>
-            <version>${scala.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>org.scala-lang</groupId>
-            <artifactId>scala-reflect</artifactId>
-            <version>${scala.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.typesafe.play</groupId>
-            <artifactId>play-java_2.11</artifactId>
-            <version>${playframework.version}</version>
-            <exclusions>
-                <exclusion>
-                    <groupId>com.google.code.findbugs</groupId>
-                    <artifactId>jsr305</artifactId>
-                </exclusion>
-                <exclusion>
-                    <groupId>net.jodah</groupId>
-                    <artifactId>typetools</artifactId>
-                </exclusion>
-            </exclusions>
-        </dependency>
-        <dependency>
-            <groupId>net.jodah</groupId>
-            <artifactId>typetools</artifactId>
-            <version>${jodah.typetools.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.typesafe.play</groupId>
-            <artifactId>play-json_2.11</artifactId>
-            <version>${playframework.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.typesafe.play</groupId>
-            <artifactId>play-server_2.11</artifactId>
-            <version>${playframework.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.typesafe.play</groupId>
-            <artifactId>play_2.11</artifactId>
-            <version>${playframework.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.typesafe.play</groupId>
-            <artifactId>play-netty-server_2.11</artifactId>
-            <version>${playframework.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.typesafe.akka</groupId>
-            <artifactId>akka-cluster_2.11</artifactId>
-            <version>2.5.23</version>
-        </dependency>
-        <dependency>
-            <groupId>com.mashape.unirest</groupId>
-            <artifactId>unirest-java</artifactId>
-            <scope>test</scope>
-        </dependency>
-        <dependency>
-            <groupId>com.beust</groupId>
-            <artifactId>jcommander</artifactId>
-            <version>${jcommander.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>org.apache.spark</groupId>
-            <artifactId>spark-core_2.11</artifactId>
-            <version>${spark.version}</version>
-        </dependency>
-    </dependencies>
-
-    <profiles>
-        <profile>
-            <id>test-nd4j-native</id>
-        </profile>
-        <profile>
-            <id>test-nd4j-cuda-11.0</id>
-        </profile>
-    </profiles>
-</project>
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/CSVSparkTransformServer.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/CSVSparkTransformServer.java
deleted file mode 100644
index 9ef085515..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/CSVSparkTransformServer.java
+++ /dev/null
@@ -1,352 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.inference.server;
-
-import com.beust.jcommander.JCommander;
-import com.beust.jcommander.ParameterException;
-import lombok.Data;
-import lombok.extern.slf4j.Slf4j;
-import org.apache.commons.io.FileUtils;
-import org.datavec.api.transform.TransformProcess;
-import org.datavec.image.transform.ImageTransformProcess;
-import org.datavec.spark.inference.model.CSVSparkTransform;
-import org.datavec.spark.inference.model.model.*;
-import play.BuiltInComponents;
-import play.Mode;
-import play.routing.Router;
-import play.routing.RoutingDsl;
-import play.server.Server;
-
-import java.io.File;
-import java.io.IOException;
-import java.util.Base64;
-import java.util.Random;
-
-import static play.mvc.Results.*;
-
-@Slf4j
-@Data
-public class CSVSparkTransformServer extends SparkTransformServer {
- private CSVSparkTransform transform;
-
- public void runMain(String[] args) throws Exception {
- JCommander jcmdr = new JCommander(this);
-
- try {
- jcmdr.parse(args);
- } catch (ParameterException e) {
- //User provides invalid input -> print the usage info
- jcmdr.usage();
- if (jsonPath == null)
- System.err.println("Json path parameter is missing.");
- try {
- Thread.sleep(500);
- } catch (Exception e2) {
- }
- System.exit(1);
- }
-
- if (jsonPath != null) {
- String json = FileUtils.readFileToString(new File(jsonPath));
- TransformProcess transformProcess = TransformProcess.fromJson(json);
- transform = new CSVSparkTransform(transformProcess);
- } else {
- log.warn("Server started with no json for transform process. Please ensure you specify a transform process via sending a post request with raw json "
- + "to /transformprocess");
- }
-
- //Set play secret key, if required
- //http://www.playframework.com/documentation/latest/ApplicationSecret
- String crypto = System.getProperty("play.crypto.secret");
- if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) {
- byte[] newCrypto = new byte[1024];
-
- new Random().nextBytes(newCrypto);
-
- String base64 = Base64.getEncoder().encodeToString(newCrypto);
- System.setProperty("play.crypto.secret", base64);
- }
-
-
- server = Server.forRouter(Mode.PROD, port, this::createRouter);
- }
-
- protected Router createRouter(BuiltInComponents b){
- RoutingDsl routingDsl = RoutingDsl.fromComponents(b);
-
- routingDsl.GET("/transformprocess").routingTo(req -> {
- try {
- if (transform == null)
- return badRequest();
- return ok(transform.getTransformProcess().toJson()).as(contentType);
- } catch (Exception e) {
- log.error("Error in GET /transformprocess",e);
- return internalServerError(e.getMessage());
- }
- });
-
- routingDsl.POST("/transformprocess").routingTo(req -> {
- try {
- TransformProcess transformProcess = TransformProcess.fromJson(getJsonText(req));
- setCSVTransformProcess(transformProcess);
- log.info("Transform process initialized");
- return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
- } catch (Exception e) {
- log.error("Error in POST /transformprocess",e);
- return internalServerError(e.getMessage());
- }
- });
-
- routingDsl.POST("/transformincremental").routingTo(req -> {
- if (isSequence(req)) {
- try {
- BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
- if (record == null)
- return badRequest();
- return ok(objectMapper.writeValueAsString(transformSequenceIncremental(record))).as(contentType);
- } catch (Exception e) {
- log.error("Error in /transformincremental", e);
- return internalServerError(e.getMessage());
- }
- } else {
- try {
- SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class);
- if (record == null)
- return badRequest();
- return ok(objectMapper.writeValueAsString(transformIncremental(record))).as(contentType);
- } catch (Exception e) {
- log.error("Error in /transformincremental", e);
- return internalServerError(e.getMessage());
- }
- }
- });
-
- routingDsl.POST("/transform").routingTo(req -> {
- if (isSequence(req)) {
- try {
- SequenceBatchCSVRecord batch = transformSequence(objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class));
- if (batch == null)
- return badRequest();
- return ok(objectMapper.writeValueAsString(batch)).as(contentType);
- } catch (Exception e) {
- log.error("Error in /transform", e);
- return internalServerError(e.getMessage());
- }
- } else {
- try {
- BatchCSVRecord input = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
- BatchCSVRecord batch = transform(input);
- if (batch == null)
- return badRequest();
- return ok(objectMapper.writeValueAsString(batch)).as(contentType);
- } catch (Exception e) {
- log.error("Error in /transform", e);
- return internalServerError(e.getMessage());
- }
- }
- });
-
- routingDsl.POST("/transformincrementalarray").routingTo(req -> {
- if (isSequence(req)) {
- try {
- BatchCSVRecord record = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
- if (record == null)
- return badRequest();
- return ok(objectMapper.writeValueAsString(transformSequenceArrayIncremental(record))).as(contentType);
- } catch (Exception e) {
- log.error("Error in /transformincrementalarray", e);
- return internalServerError(e.getMessage());
- }
- } else {
- try {
- SingleCSVRecord record = objectMapper.readValue(getJsonText(req), SingleCSVRecord.class);
- if (record == null)
- return badRequest();
- return ok(objectMapper.writeValueAsString(transformArrayIncremental(record))).as(contentType);
- } catch (Exception e) {
- log.error("Error in /transformincrementalarray", e);
- return internalServerError(e.getMessage());
- }
- }
- });
-
- routingDsl.POST("/transformarray").routingTo(req -> {
- if (isSequence(req)) {
- try {
- SequenceBatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), SequenceBatchCSVRecord.class);
- if (batchCSVRecord == null)
- return badRequest();
- return ok(objectMapper.writeValueAsString(transformSequenceArray(batchCSVRecord))).as(contentType);
- } catch (Exception e) {
- log.error("Error in /transformarray", e);
- return internalServerError(e.getMessage());
- }
- } else {
- try {
- BatchCSVRecord batchCSVRecord = objectMapper.readValue(getJsonText(req), BatchCSVRecord.class);
- if (batchCSVRecord == null)
- return badRequest();
- return ok(objectMapper.writeValueAsString(transformArray(batchCSVRecord))).as(contentType);
- } catch (Exception e) {
- log.error("Error in /transformarray", e);
- return internalServerError(e.getMessage());
- }
- }
- });
-
- return routingDsl.build();
- }
-
- public static void main(String[] args) throws Exception {
- new CSVSparkTransformServer().runMain(args);
- }
-
- /**
- * Set the CSV transform process to serve.
- *
- * @param transformProcess the transform process to use
- */
- @Override
- public void setCSVTransformProcess(TransformProcess transformProcess) {
- this.transform = new CSVSparkTransform(transformProcess);
- }
-
- @Override
- public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) {
- log.error("Unsupported operation: setImageTransformProcess not supported for class", getClass());
- throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
- }
-
- /**
- * @return the transform process currently being served
- */
- @Override
- public TransformProcess getCSVTransformProcess() {
- return transform.getTransformProcess();
- }
-
- @Override
- public ImageTransformProcess getImageTransformProcess() {
- log.error("Unsupported operation: getImageTransformProcess not supported for class", getClass());
- throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
- }
-
-
- /**
- * Incrementally transform a batch into a sequence batch.
- *
- * @param transform the batch to transform
- * @return the transformed sequence batch
- */
- @Override
- public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) {
- return this.transform.transformSequenceIncremental(transform);
- }
-
- /**
- * @param batchCSVRecord the sequence batch to transform
- * @return the transformed sequence batch
- */
- @Override
- public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) {
- return transform.transformSequence(batchCSVRecord);
- }
-
- /**
- * @param batchCSVRecord the sequence batch to convert
- * @return the result as a base64-encoded NDArray
- */
- @Override
- public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) {
- return this.transform.transformSequenceArray(batchCSVRecord);
- }
-
- /**
- * @param singleCsvRecord the batch to transform incrementally
- * @return the result as a base64-encoded NDArray
- */
- @Override
- public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) {
- return this.transform.transformSequenceArrayIncremental(singleCsvRecord);
- }
-
- /**
- * @param transform the record to transform
- * @return the transformed record
- */
- @Override
- public SingleCSVRecord transformIncremental(SingleCSVRecord transform) {
- return this.transform.transform(transform);
- }
-
- @Override
- public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
- return this.transform.transform(batchCSVRecord);
- }
-
- /**
- * @param batchCSVRecord the batch to transform
- * @return the transformed batch
- */
- @Override
- public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) {
- return transform.transform(batchCSVRecord);
- }
-
- /**
- * @param batchCSVRecord the batch to convert
- * @return the result as a base64-encoded NDArray
- */
- @Override
- public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) {
- try {
- return this.transform.toArray(batchCSVRecord);
- } catch (IOException e) {
- log.error("Error in transformArray",e);
- throw new IllegalStateException("Transform array shouldn't throw exception");
- }
- }
-
- /**
- * @param singleCsvRecord the record to convert
- * @return the result as a base64-encoded NDArray
- */
- @Override
- public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) {
- try {
- return this.transform.toArray(singleCsvRecord);
- } catch (IOException e) {
- log.error("Error in transformArrayIncremental",e);
- throw new IllegalStateException("Transform array shouldn't throw exception");
- }
- }
-
- @Override
- public Base64NDArrayBody transformIncrementalArray(SingleImageRecord singleImageRecord) throws IOException {
- log.error("Unsupported operation: transformIncrementalArray(SingleImageRecord) not supported for class", getClass());
- throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
- }
-
- @Override
- public Base64NDArrayBody transformArray(BatchImageRecord batchImageRecord) throws IOException {
- log.error("Unsupported operation: transformArray(BatchImageRecord) not supported for class", getClass());
- throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
- }
-}
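For reference, the deleted CSVSparkTransformServer above serves a small JSON REST API: GET/POST /transformprocess to read or install the TransformProcess, plus POST /transform, /transformincremental, /transformarray and /transformincrementalarray for the conversions themselves. The following client sketch is not part of the deleted sources; it assumes a locally running instance on the default port 9000 and that SingleCSVRecord serializes to a {"values": [...]} JSON object.

import com.mashape.unirest.http.HttpResponse;
import com.mashape.unirest.http.Unirest;
import org.datavec.api.transform.TransformProcess;
import org.datavec.api.transform.schema.Schema;

public class CSVTransformClientSketch {
    public static void main(String[] args) throws Exception {
        String base = "http://localhost:9000"; // default -dp port

        // Build a transform process and install it on the running server
        Schema schema = new Schema.Builder().addColumnDouble("a").addColumnDouble("b").build();
        TransformProcess tp = new TransformProcess.Builder(schema).convertToString("a").build();
        HttpResponse<String> posted = Unirest.post(base + "/transformprocess")
                .header("Content-Type", "application/json")
                .body(tp.toJson())
                .asString();
        System.out.println("POST /transformprocess -> " + posted.getStatus());

        // Transform one record; assumes SingleCSVRecord serializes as {"values": [...]}
        HttpResponse<String> transformed = Unirest.post(base + "/transformincremental")
                .header("Content-Type", "application/json")
                .body("{\"values\": [\"1.0\", \"2.0\"]}")
                .asString();
        System.out.println(transformed.getBody());
    }
}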
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/ImageSparkTransformServer.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/ImageSparkTransformServer.java
deleted file mode 100644
index e7744ecaa..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/ImageSparkTransformServer.java
+++ /dev/null
@@ -1,261 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.inference.server;
-
-import com.beust.jcommander.JCommander;
-import com.beust.jcommander.ParameterException;
-import lombok.Data;
-import lombok.extern.slf4j.Slf4j;
-import org.apache.commons.io.FileUtils;
-import org.datavec.api.transform.TransformProcess;
-import org.datavec.image.transform.ImageTransformProcess;
-import org.datavec.spark.inference.model.ImageSparkTransform;
-import org.datavec.spark.inference.model.model.*;
-import play.BuiltInComponents;
-import play.Mode;
-import play.libs.Files;
-import play.mvc.Http;
-import play.routing.Router;
-import play.routing.RoutingDsl;
-import play.server.Server;
-
-import java.io.File;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.List;
-
-import static play.mvc.Results.*;
-
-@Slf4j
-@Data
-public class ImageSparkTransformServer extends SparkTransformServer {
- private ImageSparkTransform transform;
-
- public void runMain(String[] args) throws Exception {
- JCommander jcmdr = new JCommander(this);
-
- try {
- jcmdr.parse(args);
- } catch (ParameterException e) {
- //User provides invalid input -> print the usage info
- jcmdr.usage();
- if (jsonPath == null)
- System.err.println("Json path parameter is missing.");
- try {
- Thread.sleep(500);
- } catch (Exception e2) {
- }
- System.exit(1);
- }
-
- if (jsonPath != null) {
- String json = FileUtils.readFileToString(new File(jsonPath));
- ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(json);
- transform = new ImageSparkTransform(transformProcess);
- } else {
- log.warn("Server started with no json for transform process. Please ensure you specify a transform process via sending a post request with raw json "
- + "to /transformprocess");
- }
-
- server = Server.forRouter(Mode.PROD, port, this::createRouter);
- }
-
- protected Router createRouter(BuiltInComponents builtInComponents){
- RoutingDsl routingDsl = RoutingDsl.fromComponents(builtInComponents);
-
- routingDsl.GET("/transformprocess").routingTo(req -> {
- try {
- if (transform == null)
- return badRequest();
- log.info("Returning current image transform process");
- return ok(objectMapper.writeValueAsString(transform.getImageTransformProcess())).as(contentType);
- } catch (Exception e) {
- log.error("",e);
- return internalServerError();
- }
- });
-
- routingDsl.POST("/transformprocess").routingTo(req -> {
- try {
- ImageTransformProcess transformProcess = ImageTransformProcess.fromJson(getJsonText(req));
- setImageTransformProcess(transformProcess);
- log.info("Transform process initialized");
- return ok(objectMapper.writeValueAsString(transformProcess)).as(contentType);
- } catch (Exception e) {
- log.error("",e);
- return internalServerError();
- }
- });
-
- routingDsl.POST("/transformincrementalarray").routingTo(req -> {
- try {
- SingleImageRecord record = objectMapper.readValue(getJsonText(req), SingleImageRecord.class);
- if (record == null)
- return badRequest();
- return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
- } catch (Exception e) {
- log.error("",e);
- return internalServerError();
- }
- });
-
- routingDsl.POST("/transformincrementalimage").routingTo(req -> {
- try {
- Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData();
- List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles();
- if (files.isEmpty() || files.get(0).getRef() == null) {
- return badRequest();
- }
-
- File file = files.get(0).getRef().path().toFile();
- SingleImageRecord record = new SingleImageRecord(file.toURI());
-
- return ok(objectMapper.writeValueAsString(transformIncrementalArray(record))).as(contentType);
- } catch (Exception e) {
- log.error("",e);
- return internalServerError();
- }
- });
-
- routingDsl.POST("/transformarray").routingTo(req -> {
- try {
- BatchImageRecord batch = objectMapper.readValue(getJsonText(req), BatchImageRecord.class);
- if (batch == null)
- return badRequest();
- return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
- } catch (Exception e) {
- log.error("",e);
- return internalServerError();
- }
- });
-
- routingDsl.POST("/transformimage").routingTo(req -> {
- try {
- Http.MultipartFormData<Files.TemporaryFile> body = req.body().asMultipartFormData();
- List<Http.MultipartFormData.FilePart<Files.TemporaryFile>> files = body.getFiles();
- if (files.isEmpty()) {
- return badRequest();
- }
-
- List records = new ArrayList<>();
-
- for (Http.MultipartFormData.FilePart<Files.TemporaryFile> filePart : files) {
- Files.TemporaryFile file = filePart.getRef();
- if (file != null) {
- SingleImageRecord record = new SingleImageRecord(file.path().toUri());
- records.add(record);
- }
- }
-
- BatchImageRecord batch = new BatchImageRecord(records);
-
- return ok(objectMapper.writeValueAsString(transformArray(batch))).as(contentType);
- } catch (Exception e) {
- log.error("",e);
- return internalServerError();
- }
- });
-
- return routingDsl.build();
- }
-
- @Override
- public Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord) {
- throw new UnsupportedOperationException();
- }
-
- @Override
- public Base64NDArrayBody transformSequenceArray(SequenceBatchCSVRecord batchCSVRecord) {
- throw new UnsupportedOperationException();
-
- }
-
- @Override
- public SequenceBatchCSVRecord transformSequence(SequenceBatchCSVRecord batchCSVRecord) {
- throw new UnsupportedOperationException();
-
- }
-
- @Override
- public SequenceBatchCSVRecord transformSequenceIncremental(BatchCSVRecord transform) {
- throw new UnsupportedOperationException();
-
- }
-
- @Override
- public void setCSVTransformProcess(TransformProcess transformProcess) {
- throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
- }
-
- @Override
- public void setImageTransformProcess(ImageTransformProcess imageTransformProcess) {
- this.transform = new ImageSparkTransform(imageTransformProcess);
- }
-
- @Override
- public TransformProcess getCSVTransformProcess() {
- throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
- }
-
- @Override
- public ImageTransformProcess getImageTransformProcess() {
- return transform.getImageTransformProcess();
- }
-
- @Override
- public SingleCSVRecord transformIncremental(SingleCSVRecord singleCsvRecord) {
- throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
- }
-
- @Override
- public SequenceBatchCSVRecord transform(SequenceBatchCSVRecord batchCSVRecord) {
- throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
- }
-
- @Override
- public BatchCSVRecord transform(BatchCSVRecord batchCSVRecord) {
- throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
- }
-
- @Override
- public Base64NDArrayBody transformArray(BatchCSVRecord batchCSVRecord) {
- throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
- }
-
- @Override
- public Base64NDArrayBody transformArrayIncremental(SingleCSVRecord singleCsvRecord) {
- throw new UnsupportedOperationException("Invalid operation for " + this.getClass());
- }
-
- @Override
- public Base64NDArrayBody transformIncrementalArray(SingleImageRecord record) throws IOException {
- return transform.toArray(record);
- }
-
- @Override
- public Base64NDArrayBody transformArray(BatchImageRecord batch) throws IOException {
- return transform.toArray(batch);
- }
-
- public static void main(String[] args) throws Exception {
- new ImageSparkTransformServer().runMain(args);
- }
-}
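The image server above accepts both JSON bodies (/transformincrementalarray, /transformarray) and raw multipart uploads (/transformincrementalimage, /transformimage). A minimal upload sketch, again assuming a local instance on the default port 9000; the form field names are arbitrary since the deleted handler iterates over every file part, and the paths are illustrative.

import com.mashape.unirest.http.HttpResponse;
import com.mashape.unirest.http.Unirest;

import java.io.File;

public class ImageTransformClientSketch {
    public static void main(String[] args) throws Exception {
        // Upload two images as multipart form data; the server wraps each part
        // in a SingleImageRecord and responds with a base64-encoded NDArray
        HttpResponse<String> resp = Unirest.post("http://localhost:9000/transformimage")
                .field("file1", new File("/tmp/A.jpg"))
                .field("file2", new File("/tmp/B.png"))
                .asString();
        System.out.println(resp.getBody());
    }
}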
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/SparkTransformServer.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/SparkTransformServer.java
deleted file mode 100644
index c89ef90cc..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/SparkTransformServer.java
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.inference.server;
-
-import com.beust.jcommander.Parameter;
-import com.fasterxml.jackson.databind.JsonNode;
-import org.datavec.spark.inference.model.model.Base64NDArrayBody;
-import org.datavec.spark.inference.model.model.BatchCSVRecord;
-import org.datavec.spark.inference.model.service.DataVecTransformService;
-import org.nd4j.shade.jackson.databind.ObjectMapper;
-import play.mvc.Http;
-import play.server.Server;
-
-public abstract class SparkTransformServer implements DataVecTransformService {
- @Parameter(names = {"-j", "--jsonPath"}, arity = 1)
- protected String jsonPath = null;
- @Parameter(names = {"-dp", "--dataVecPort"}, arity = 1)
- protected int port = 9000;
- @Parameter(names = {"-dt", "--dataType"}, arity = 1)
- private TransformDataType transformDataType = null;
- protected Server server;
- protected static ObjectMapper objectMapper = new ObjectMapper();
- protected static String contentType = "application/json";
-
- public abstract void runMain(String[] args) throws Exception;
-
- /**
- * Stop the server
- */
- public void stop() {
- if (server != null)
- server.stop();
- }
-
- protected boolean isSequence(Http.Request request) {
- return request.hasHeader(SEQUENCE_OR_NOT_HEADER)
- && request.header(SEQUENCE_OR_NOT_HEADER).get().equalsIgnoreCase("true");
- }
-
- protected String getJsonText(Http.Request request) {
- JsonNode tryJson = request.body().asJson();
- if (tryJson != null)
- return tryJson.toString();
- else
- return request.body().asText();
- }
-
- public abstract Base64NDArrayBody transformSequenceArrayIncremental(BatchCSVRecord singleCsvRecord);
-}
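The abstract base above defines the shared CLI flags: -j/--jsonPath for the serialized transform process, -dp/--dataVecPort for the port (default 9000), and -dt/--dataType for the chooser. A minimal programmatic launch sketch, not from the deleted sources; the JSON path is illustrative.

import org.datavec.spark.inference.server.CSVSparkTransformServer;

public class LaunchSketch {
    public static void main(String[] args) throws Exception {
        CSVSparkTransformServer server = new CSVSparkTransformServer();
        // equivalent to: java ... CSVSparkTransformServer -j /tmp/tp.json -dp 9050
        server.runMain(new String[] {"-j", "/tmp/tp.json", "-dp", "9050"});
        // ... serve requests, then shut down via the base class helper
        server.stop();
    }
}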
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/SparkTransformServerChooser.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/SparkTransformServerChooser.java
deleted file mode 100644
index aa4945ddb..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/SparkTransformServerChooser.java
+++ /dev/null
@@ -1,76 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.inference.server;
-
-import lombok.Data;
-import lombok.extern.slf4j.Slf4j;
-
-import java.io.InvalidClassException;
-import java.util.Arrays;
-import java.util.List;
-
-@Data
-@Slf4j
-public class SparkTransformServerChooser {
- private SparkTransformServer sparkTransformServer = null;
- private TransformDataType transformDataType = null;
-
- public void runMain(String[] args) throws Exception {
-
- int pos = getMatchingPosition(args, "-dt", "--dataType");
- if (pos == -1 || pos + 1 >= args.length) {
- log.error("No data type specified; a value for -dt/--dataType is required");
- log.error("-dt, --dataType Options: [CSV, IMAGE]");
- throw new Exception("Missing required option -dt/--dataType (CSV or IMAGE)");
- } else {
- transformDataType = TransformDataType.valueOf(args[pos + 1]);
- }
-
- switch (transformDataType) {
- case CSV:
- sparkTransformServer = new CSVSparkTransformServer();
- break;
- case IMAGE:
- sparkTransformServer = new ImageSparkTransformServer();
- break;
- default:
- throw new InvalidClassException("no matching SparkTransform class");
- }
-
- sparkTransformServer.runMain(args);
- }
-
- private int getMatchingPosition(String[] args, String... options) {
- List optionList = Arrays.asList(options);
-
- for (int i = 0; i < args.length; i++) {
- if (optionList.contains(args[i])) {
- return i;
- }
- }
- return -1;
- }
-
-
- public static void main(String[] args) throws Exception {
- new SparkTransformServerChooser().runMain(args);
- }
-}
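The chooser simply dispatches on -dt and forwards the full argument list, so a single entry point can start either server flavor. A hypothetical invocation, with an illustrative JSON path:

import org.datavec.spark.inference.server.SparkTransformServerChooser;

public class ChooserSketch {
    public static void main(String[] args) throws Exception {
        // starts a CSVSparkTransformServer; "-dt", "IMAGE" would start the image server
        new SparkTransformServerChooser().runMain(
                new String[] {"-dt", "CSV", "-j", "/tmp/tp.json", "-dp", "9050"});
    }
}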
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/TransformDataType.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/TransformDataType.java
deleted file mode 100644
index 643cd5652..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/java/org/datavec/spark/inference/server/TransformDataType.java
+++ /dev/null
@@ -1,25 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.inference.server;
-
-public enum TransformDataType {
- CSV, IMAGE
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/resources/application.conf b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/resources/application.conf
deleted file mode 100644
index 28a4aa208..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/main/resources/application.conf
+++ /dev/null
@@ -1,350 +0,0 @@
-# This is the main configuration file for the application.
-# https://www.playframework.com/documentation/latest/ConfigFile
-# ~~~~~
-# Play uses HOCON as its configuration file format. HOCON has a number
-# of advantages over other config formats, but there are a few things to
-# know when modifying settings.
-#
-# You can include other configuration files in this main application.conf file:
-#include "extra-config.conf"
-#
-# You can declare variables and substitute for them:
-#mykey = ${some.value}
-#
-# And if an environment variable exists when there is no other substitution, then
-# HOCON will fall back to substituting the environment variable:
-#mykey = ${JAVA_HOME}
-
-## Akka
-# https://www.playframework.com/documentation/latest/ScalaAkka#Configuration
-# https://www.playframework.com/documentation/latest/JavaAkka#Configuration
-# ~~~~~
-# Play uses Akka internally and exposes Akka Streams and actors in Websockets and
-# other streaming HTTP responses.
-akka {
- # "akka.log-config-on-start" is extraordinarly useful because it log the complete
- # configuration at INFO level, including defaults and overrides, so it s worth
- # putting at the very top.
- #
- # Put the following in your conf/logback.xml file:
- #
- # <logger name="akka.actor" level="INFO" />
- #
- # And then uncomment this line to debug the configuration.
- #
- #log-config-on-start = true
-}
-
-## Modules
-# https://www.playframework.com/documentation/latest/Modules
-# ~~~~~
-# Control which modules are loaded when Play starts. Note that modules are
-# the replacement for "GlobalSettings", which are deprecated in 2.5.x.
-# Please see https://www.playframework.com/documentation/latest/GlobalSettings
-# for more information.
-#
-# You can also extend Play functionality by using one of the publicly available
-# Play modules: https://playframework.com/documentation/latest/ModuleDirectory
-play.modules {
- # By default, Play will load any class called Module that is defined
- # in the root package (the "app" directory), or you can define them
- # explicitly below.
- # If there are any built-in modules that you want to enable, you can list them here.
- #enabled += my.application.Module
-
- # If there are any built-in modules that you want to disable, you can list them here.
- #disabled += ""
-}
-
-## Internationalisation
-# https://www.playframework.com/documentation/latest/JavaI18N
-# https://www.playframework.com/documentation/latest/ScalaI18N
-# ~~~~~
-# Play comes with its own i18n settings, which allow the user's preferred language
-# to map through to internal messages, or allow the language to be stored in a cookie.
-play.i18n {
- # The application languages
- langs = [ "en" ]
-
- # Whether the language cookie should be secure or not
- #langCookieSecure = true
-
- # Whether the HTTP only attribute of the cookie should be set to true
- #langCookieHttpOnly = true
-}
-
-## Play HTTP settings
-# ~~~~~
-play.http {
- ## Router
- # https://www.playframework.com/documentation/latest/JavaRouting
- # https://www.playframework.com/documentation/latest/ScalaRouting
- # ~~~~~
- # Define the Router object to use for this application.
- # This router will be looked up first when the application is starting up,
- # so make sure this is the entry point.
- # Furthermore, it's assumed your route file is named properly.
- # So for an application router like `my.application.Router`,
- # you may need to define a router file `conf/my.application.routes`.
- # Default to Routes in the root package (aka "apps" folder) (and conf/routes)
- #router = my.application.Router
-
- ## Action Creator
- # https://www.playframework.com/documentation/latest/JavaActionCreator
- # ~~~~~
- #actionCreator = null
-
- ## ErrorHandler
- # https://www.playframework.com/documentation/latest/JavaRouting
- # https://www.playframework.com/documentation/latest/ScalaRouting
- # ~~~~~
- # If null, will attempt to load a class called ErrorHandler in the root package,
- #errorHandler = null
-
- ## Filters
- # https://www.playframework.com/documentation/latest/ScalaHttpFilters
- # https://www.playframework.com/documentation/latest/JavaHttpFilters
- # ~~~~~
- # Filters run code on every request. They can be used to perform
- # common logic for all your actions, e.g. adding common headers.
- # Defaults to "Filters" in the root package (aka "apps" folder)
- # Alternatively you can explicitly register a class here.
- #filters += my.application.Filters
-
- ## Session & Flash
- # https://www.playframework.com/documentation/latest/JavaSessionFlash
- # https://www.playframework.com/documentation/latest/ScalaSessionFlash
- # ~~~~~
- session {
- # Sets the cookie to be sent only over HTTPS.
- #secure = true
-
- # Sets the cookie to be accessed only by the server.
- #httpOnly = true
-
- # Sets the max-age field of the cookie to 5 minutes.
- # NOTE: this only controls when the browser will discard the cookie. Play will consider any
- # cookie value with a valid signature to be a valid session forever. To implement a server side session timeout,
- # you need to put a timestamp in the session and check it at regular intervals to possibly expire it.
- #maxAge = 300
-
- # Sets the domain on the session cookie.
- #domain = "example.com"
- }
-
- flash {
- # Sets the cookie to be sent only over HTTPS.
- #secure = true
-
- # Sets the cookie to be accessed only by the server.
- #httpOnly = true
- }
-}
-
-## Netty Provider
-# https://www.playframework.com/documentation/latest/SettingsNetty
-# ~~~~~
-play.server.netty {
- # Whether the Netty wire should be logged
- #log.wire = true
-
- # If you run Play on Linux, you can use Netty's native socket transport
- # for higher performance with less garbage.
- #transport = "native"
-}
-
-## WS (HTTP Client)
-# https://www.playframework.com/documentation/latest/ScalaWS#Configuring-WS
-# ~~~~~
-# The HTTP client primarily used for REST APIs. The default client can be
-# configured directly, but you can also create different client instances
-# with customized settings. You must enable this by adding to build.sbt:
-#
-# libraryDependencies += ws // or javaWs if using java
-#
-play.ws {
- # Sets HTTP requests not to follow 302 redirects
- #followRedirects = false
-
- # Sets the maximum number of open HTTP connections for the client.
- #ahc.maxConnectionsTotal = 50
-
- ## WS SSL
- # https://www.playframework.com/documentation/latest/WsSSL
- # ~~~~~
- ssl {
- # Configuring HTTPS with Play WS does not require programming. You can
- # set up both trustManager and keyManager for mutual authentication, and
- # turn on JSSE debugging in development with a reload.
- #debug.handshake = true
- #trustManager = {
- # stores = [
- # { type = "JKS", path = "exampletrust.jks" }
- # ]
- #}
- }
-}
-
-## Cache
-# https://www.playframework.com/documentation/latest/JavaCache
-# https://www.playframework.com/documentation/latest/ScalaCache
-# ~~~~~
-# Play comes with an integrated cache API that can reduce the operational
-# overhead of repeated requests. You must enable this by adding to build.sbt:
-#
-# libraryDependencies += cache
-#
-play.cache {
- # If you want to bind several caches, you can bind them individually
- #bindCaches = ["db-cache", "user-cache", "session-cache"]
-}
-
-## Filters
-# https://www.playframework.com/documentation/latest/Filters
-# ~~~~~
-# There are a number of built-in filters that can be enabled and configured
-# to give Play greater security. You must enable this by adding to build.sbt:
-#
-# libraryDependencies += filters
-#
-play.filters {
- ## CORS filter configuration
- # https://www.playframework.com/documentation/latest/CorsFilter
- # ~~~~~
- # CORS is a protocol that allows web applications to make requests from the browser
- # across different domains.
- # NOTE: You MUST apply the CORS configuration before the CSRF filter, as CSRF has
- # dependencies on CORS settings.
- cors {
- # Filter paths by a whitelist of path prefixes
- #pathPrefixes = ["/some/path", ...]
-
- # The allowed origins. If null, all origins are allowed.
- #allowedOrigins = ["http://www.example.com"]
-
- # The allowed HTTP methods. If null, all methods are allowed
- #allowedHttpMethods = ["GET", "POST"]
- }
-
- ## CSRF Filter
- # https://www.playframework.com/documentation/latest/ScalaCsrf#Applying-a-global-CSRF-filter
- # https://www.playframework.com/documentation/latest/JavaCsrf#Applying-a-global-CSRF-filter
- # ~~~~~
- # Play supports multiple methods for verifying that a request is not a CSRF request.
- # The primary mechanism is a CSRF token. This token gets placed either in the query string
- # or body of every form submitted, and also gets placed in the users session.
- # Play then verifies that both tokens are present and match.
- csrf {
- # Sets the cookie to be sent only over HTTPS
- #cookie.secure = true
-
- # Defaults to CSRFErrorHandler in the root package.
- #errorHandler = MyCSRFErrorHandler
- }
-
- ## Security headers filter configuration
- # https://www.playframework.com/documentation/latest/SecurityHeaders
- # ~~~~~
- # Defines security headers that help protect against XSS and related attacks.
- # If enabled, all options default to the configuration below:
- headers {
- # The X-Frame-Options header. If null, the header is not set.
- #frameOptions = "DENY"
-
- # The X-XSS-Protection header. If null, the header is not set.
- #xssProtection = "1; mode=block"
-
- # The X-Content-Type-Options header. If null, the header is not set.
- #contentTypeOptions = "nosniff"
-
- # The X-Permitted-Cross-Domain-Policies header. If null, the header is not set.
- #permittedCrossDomainPolicies = "master-only"
-
- # The Content-Security-Policy header. If null, the header is not set.
- #contentSecurityPolicy = "default-src 'self'"
- }
-
- ## Allowed hosts filter configuration
- # https://www.playframework.com/documentation/latest/AllowedHostsFilter
- # ~~~~~
- # Play provides a filter that lets you configure which hosts can access your application.
- # This is useful to prevent cache poisoning attacks.
- hosts {
- # Allow requests to example.com, its subdomains, and localhost:9000.
- #allowed = [".example.com", "localhost:9000"]
- }
-}
-
-## Evolutions
-# https://www.playframework.com/documentation/latest/Evolutions
-# ~~~~~
-# Evolutions allows database migration scripts to be run automatically on
-# startup in dev mode. You must enable this by adding to build.sbt:
-#
-# libraryDependencies += evolutions
-#
-play.evolutions {
- # You can disable evolutions for a specific datasource if necessary
- #db.default.enabled = false
-}
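-
-# Scripts live under conf/evolutions/<datasource>/ and are numbered 1.sql,
-# 2.sql, and so on. Each pairs an Ups section with a Downs section; a minimal
-# sketch of conf/evolutions/default/1.sql (table name assumed):
-#
-#   # --- !Ups
-#   CREATE TABLE users (id BIGINT PRIMARY KEY, name VARCHAR(255));
-#
-#   # --- !Downs
-#   DROP TABLE users;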
-
-## Database Connection Pool
-# https://www.playframework.com/documentation/latest/SettingsJDBC
-# ~~~~~
-# Play doesn't require a JDBC database to run, but you can easily enable one.
-#
-# libraryDependencies += jdbc
-#
-play.db {
- # The combination of these two settings results in "db.default" as the
- # default JDBC pool:
- #config = "db"
- #default = "default"
-
- # Play uses HikariCP as the default connection pool. You can override
- # settings by changing the prototype:
- prototype {
- # Sets a fixed JDBC connection pool size of 50
- #hikaricp.minimumIdle = 50
- #hikaricp.maximumPoolSize = 50
- }
-}
-
-## JDBC Datasource
-# https://www.playframework.com/documentation/latest/JavaDatabase
-# https://www.playframework.com/documentation/latest/ScalaDatabase
-# ~~~~~
-# Once the JDBC datasource is set up, you can work with several different
-# database options:
-#
-# Slick (Scala preferred option): https://www.playframework.com/documentation/latest/PlaySlick
-# JPA (Java preferred option): https://playframework.com/documentation/latest/JavaJPA
-# EBean: https://playframework.com/documentation/latest/JavaEbean
-# Anorm: https://www.playframework.com/documentation/latest/ScalaAnorm
-#
-db {
- # You can declare as many datasources as you want.
- # By convention, the default datasource is named `default`
-
- # https://www.playframework.com/documentation/latest/Developing-with-the-H2-Database
- default.driver = org.h2.Driver
- default.url = "jdbc:h2:mem:play"
- #default.username = sa
- #default.password = ""
-
- # You can expose this datasource via JNDI if needed (Useful for JPA)
- default.jndiName=DefaultDS
-
- # You can turn on SQL logging for any datasource
- # https://www.playframework.com/documentation/latest/Highlights25#Logging-SQL-statements
- #default.logSql=true
-}
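-
-# A hedged Java sketch of using the default datasource directly (assumes the
-# jdbc dependency and an injected play.db.Database `db`):
-#
-#   boolean ok = db.withConnection(connection -> {
-#       try (Statement stmt = connection.createStatement()) {
-#           return stmt.executeQuery("SELECT 1").next();
-#       }
-#   });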
-
-jpa.default=defaultPersistenceUnit
-
-
-# Increase the default maximum POST length - used for the remote listener functionality.
-# Without this, larger networks can trigger HTTP 413 (Request Entity Too Large) responses.
-# parsers.text.maxLength is deprecated; use play.http.parser.maxMemoryBuffer instead.
-#parsers.text.maxLength=10M
-play.http.parser.maxMemoryBuffer=10M
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java
deleted file mode 100644
index ab76b206e..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/AssertTestsExtendBaseClass.java
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-package org.datavec.spark.transform;
-
-import lombok.extern.slf4j.Slf4j;
-import org.nd4j.common.tests.AbstractAssertTestsClass;
-import org.nd4j.common.tests.BaseND4JTest;
-
-import java.util.*;
-
-@Slf4j
-public class AssertTestsExtendBaseClass extends AbstractAssertTestsClass {
-
- @Override
- protected Set<Class<?>> getExclusions() {
- //Set of classes that are exclusions to the rule (either run manually or have their own logging + timeouts)
- return new HashSet<>();
- }
-
- @Override
- protected String getPackageName() {
- return "org.datavec.spark.transform";
- }
-
- @Override
- protected Class<?> getBaseClass() {
- return BaseND4JTest.class;
- }
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java
deleted file mode 100644
index 8f309caff..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerNoJsonTest.java
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.transform;
-
-import com.mashape.unirest.http.JsonNode;
-import com.mashape.unirest.http.ObjectMapper;
-import com.mashape.unirest.http.Unirest;
-import org.apache.commons.io.FileUtils;
-import org.datavec.api.transform.TransformProcess;
-import org.datavec.api.transform.schema.Schema;
-import org.datavec.spark.inference.server.CSVSparkTransformServer;
-import org.datavec.spark.inference.model.model.Base64NDArrayBody;
-import org.datavec.spark.inference.model.model.BatchCSVRecord;
-import org.datavec.spark.inference.model.model.SingleCSVRecord;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.Test;
-
-import java.io.File;
-import java.io.IOException;
-import java.util.UUID;
-
-import static org.junit.Assert.assertTrue;
-import static org.junit.Assume.assumeNotNull;
-
-public class CSVSparkTransformServerNoJsonTest {
-
- private static CSVSparkTransformServer server;
- private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
- private static TransformProcess transformProcess =
- new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build();
- private static File fileSave = new File(UUID.randomUUID().toString() + ".json");
-
- @BeforeClass
- public static void before() throws Exception {
- server = new CSVSparkTransformServer();
- FileUtils.write(fileSave, transformProcess.toJson());
-
- // Configure the Unirest object mapper only once
- Unirest.setObjectMapper(new ObjectMapper() {
- private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
- new org.nd4j.shade.jackson.databind.ObjectMapper();
-
- public <T> T readValue(String value, Class<T> valueType) {
- try {
- return jacksonObjectMapper.readValue(value, valueType);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- public String writeValue(Object value) {
- try {
- return jacksonObjectMapper.writeValueAsString(value);
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
- });
-
- server.runMain(new String[] {"-dp", "9050"});
- }
-
- @AfterClass
- public static void after() throws Exception {
- fileSave.delete();
- server.stop();
-
- }
-
-
-
- @Test
- public void testServer() throws Exception {
- assertTrue(server.getTransform() == null);
- JsonNode jsonStatus = Unirest.post("http://localhost:9050/transformprocess")
- .header("accept", "application/json").header("Content-Type", "application/json")
- .body(transformProcess.toJson()).asJson().getBody();
- assumeNotNull(server.getTransform());
-
- String[] values = new String[] {"1.0", "2.0"};
- SingleCSVRecord record = new SingleCSVRecord(values);
- JsonNode jsonNode =
- Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json")
- .header("Content-Type", "application/json").body(record).asJson().getBody();
- SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental")
- .header("accept", "application/json").header("Content-Type", "application/json").body(record)
- .asObject(SingleCSVRecord.class).getBody();
-
- BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
- for (int i = 0; i < 3; i++)
- batchCSVRecord.add(singleCsvRecord);
- /* BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform")
- .header("accept", "application/json").header("Content-Type", "application/json")
- .body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody();
-
- Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray")
- .header("accept", "application/json").header("Content-Type", "application/json").body(record)
- .asObject(Base64NDArrayBody.class).getBody();
-*/
- Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray")
- .header("accept", "application/json").header("Content-Type", "application/json")
- .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody();
-
-
-
- }
-
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java
deleted file mode 100644
index a3af5f2c6..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/CSVSparkTransformServerTest.java
+++ /dev/null
@@ -1,121 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.transform;
-
-
-import com.mashape.unirest.http.JsonNode;
-import com.mashape.unirest.http.ObjectMapper;
-import com.mashape.unirest.http.Unirest;
-import org.apache.commons.io.FileUtils;
-import org.datavec.api.transform.TransformProcess;
-import org.datavec.api.transform.schema.Schema;
-import org.datavec.spark.inference.server.CSVSparkTransformServer;
-import org.datavec.spark.inference.model.model.Base64NDArrayBody;
-import org.datavec.spark.inference.model.model.BatchCSVRecord;
-import org.datavec.spark.inference.model.model.SingleCSVRecord;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.Test;
-
-import java.io.File;
-import java.io.IOException;
-import java.util.UUID;
-
-public class CSVSparkTransformServerTest {
-
- private static CSVSparkTransformServer server;
- private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
- private static TransformProcess transformProcess =
- new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble("2.0").build();
- private static File fileSave = new File(UUID.randomUUID().toString() + ".json");
-
- @BeforeClass
- public static void before() throws Exception {
- server = new CSVSparkTransformServer();
- FileUtils.write(fileSave, transformProcess.toJson());
- // Configure the Unirest object mapper only once
-
- Unirest.setObjectMapper(new ObjectMapper() {
- private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
- new org.nd4j.shade.jackson.databind.ObjectMapper();
-
- public <T> T readValue(String value, Class<T> valueType) {
- try {
- return jacksonObjectMapper.readValue(value, valueType);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- public String writeValue(Object value) {
- try {
- return jacksonObjectMapper.writeValueAsString(value);
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
- });
-
- server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9050"});
- }
-
- @AfterClass
- public static void after() throws Exception {
- fileSave.deleteOnExit();
- server.stop();
-
- }
-
-
-
- @Test
- public void testServer() throws Exception {
- String[] values = new String[] {"1.0", "2.0"};
- SingleCSVRecord record = new SingleCSVRecord(values);
- JsonNode jsonNode =
- Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json")
- .header("Content-Type", "application/json").body(record).asJson().getBody();
- SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental")
- .header("accept", "application/json").header("Content-Type", "application/json").body(record)
- .asObject(SingleCSVRecord.class).getBody();
-
- BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
- for (int i = 0; i < 3; i++)
- batchCSVRecord.add(singleCsvRecord);
- BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform")
- .header("accept", "application/json").header("Content-Type", "application/json")
- .body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody();
-
- Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray")
- .header("accept", "application/json").header("Content-Type", "application/json").body(record)
- .asObject(Base64NDArrayBody.class).getBody();
-
- Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray")
- .header("accept", "application/json").header("Content-Type", "application/json")
- .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody();
-
-
-
-
-
- }
-
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java
deleted file mode 100644
index 12f754acd..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/ImageSparkTransformServerTest.java
+++ /dev/null
@@ -1,164 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.transform;
-
-
-import com.mashape.unirest.http.JsonNode;
-import com.mashape.unirest.http.ObjectMapper;
-import com.mashape.unirest.http.Unirest;
-import org.apache.commons.io.FileUtils;
-import org.datavec.image.transform.ImageTransformProcess;
-import org.datavec.spark.inference.server.ImageSparkTransformServer;
-import org.datavec.spark.inference.model.model.Base64NDArrayBody;
-import org.datavec.spark.inference.model.model.BatchImageRecord;
-import org.datavec.spark.inference.model.model.SingleImageRecord;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.TemporaryFolder;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.common.io.ClassPathResource;
-import org.nd4j.serde.base64.Nd4jBase64;
-
-import java.io.File;
-import java.io.IOException;
-import java.util.UUID;
-
-import static org.junit.Assert.assertEquals;
-
-public class ImageSparkTransformServerTest {
-
- @Rule
- public TemporaryFolder testDir = new TemporaryFolder();
-
- private static ImageSparkTransformServer server;
- private static File fileSave = new File(UUID.randomUUID().toString() + ".json");
-
- @BeforeClass
- public static void before() throws Exception {
- server = new ImageSparkTransformServer();
-
- ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(12345)
- .scaleImageTransform(10).cropImageTransform(5).build();
-
- FileUtils.write(fileSave, imgTransformProcess.toJson());
-
- Unirest.setObjectMapper(new ObjectMapper() {
- private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
- new org.nd4j.shade.jackson.databind.ObjectMapper();
-
- public <T> T readValue(String value, Class<T> valueType) {
- try {
- return jacksonObjectMapper.readValue(value, valueType);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- public String writeValue(Object value) {
- try {
- return jacksonObjectMapper.writeValueAsString(value);
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
- });
-
- server.runMain(new String[] {"--jsonPath", fileSave.getAbsolutePath(), "-dp", "9060"});
- }
-
- @AfterClass
- public static void after() throws Exception {
- fileSave.deleteOnExit();
- server.stop();
-
- }
-
- @Test
- public void testImageServer() throws Exception {
- SingleImageRecord record =
- new SingleImageRecord(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI());
- JsonNode jsonNode = Unirest.post("http://localhost:9060/transformincrementalarray")
- .header("accept", "application/json").header("Content-Type", "application/json").body(record)
- .asJson().getBody();
- Base64NDArrayBody array = Unirest.post("http://localhost:9060/transformincrementalarray")
- .header("accept", "application/json").header("Content-Type", "application/json").body(record)
- .asObject(Base64NDArrayBody.class).getBody();
-
- BatchImageRecord batch = new BatchImageRecord();
- batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI());
- batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile().toURI());
- batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile().toURI());
-
- JsonNode jsonNodeBatch =
- Unirest.post("http://localhost:9060/transformarray").header("accept", "application/json")
- .header("Content-Type", "application/json").body(batch).asJson().getBody();
- Base64NDArrayBody batchArray = Unirest.post("http://localhost:9060/transformarray")
- .header("accept", "application/json").header("Content-Type", "application/json").body(batch)
- .asObject(Base64NDArrayBody.class).getBody();
-
- INDArray result = getNDArray(jsonNode);
- assertEquals(1, result.size(0));
-
- INDArray batchResult = getNDArray(jsonNodeBatch);
- assertEquals(3, batchResult.size(0));
-
-// System.out.println(array);
- }
-
- @Test
- public void testImageServerMultipart() throws Exception {
- JsonNode jsonNode = Unirest.post("http://localhost:9060/transformimage")
- .header("accept", "application/json")
- .field("file1", new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile())
- .field("file2", new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile())
- .field("file3", new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile())
- .asJson().getBody();
-
-
- INDArray batchResult = getNDArray(jsonNode);
- assertEquals(3, batchResult.size(0));
-
-// System.out.println(batchResult);
- }
-
- @Test
- public void testImageServerSingleMultipart() throws Exception {
- File f = testDir.newFolder();
- File imgFile = new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getTempFileFromArchive(f);
-
- JsonNode jsonNode = Unirest.post("http://localhost:9060/transformimage")
- .header("accept", "application/json")
- .field("file1", imgFile)
- .asJson().getBody();
-
-
- INDArray result = getNDArray(jsonNode);
- assertEquals(1, result.size(0));
-
-// System.out.println(result);
- }
-
- public INDArray getNDArray(JsonNode node) throws IOException {
- return Nd4jBase64.fromBase64(node.getObject().getString("ndarray"));
- }
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java
deleted file mode 100644
index 831dd24f4..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/java/org/datavec/spark/transform/SparkTransformServerTest.java
+++ /dev/null
@@ -1,168 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.datavec.spark.transform;
-
-
-import com.mashape.unirest.http.JsonNode;
-import com.mashape.unirest.http.ObjectMapper;
-import com.mashape.unirest.http.Unirest;
-import org.apache.commons.io.FileUtils;
-import org.datavec.api.transform.TransformProcess;
-import org.datavec.api.transform.schema.Schema;
-import org.datavec.image.transform.ImageTransformProcess;
-import org.datavec.spark.inference.server.SparkTransformServerChooser;
-import org.datavec.spark.inference.server.TransformDataType;
-import org.datavec.spark.inference.model.model.*;
-import org.junit.AfterClass;
-import org.junit.BeforeClass;
-import org.junit.Test;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.common.io.ClassPathResource;
-import org.nd4j.serde.base64.Nd4jBase64;
-
-import java.io.File;
-import java.io.IOException;
-import java.util.UUID;
-
-import static org.junit.Assert.assertEquals;
-
-public class SparkTransformServerTest {
- private static SparkTransformServerChooser serverChooser;
- private static Schema schema = new Schema.Builder().addColumnDouble("1.0").addColumnDouble("2.0").build();
- private static TransformProcess transformProcess =
- new TransformProcess.Builder(schema).convertToDouble("1.0").convertToDouble( "2.0").build();
-
- private static File imageTransformFile = new File(UUID.randomUUID().toString() + ".json");
- private static File csvTransformFile = new File(UUID.randomUUID().toString() + ".json");
-
- @BeforeClass
- public static void before() throws Exception {
- serverChooser = new SparkTransformServerChooser();
-
- ImageTransformProcess imgTransformProcess = new ImageTransformProcess.Builder().seed(12345)
- .scaleImageTransform(10).cropImageTransform(5).build();
-
- FileUtils.write(imageTransformFile, imgTransformProcess.toJson());
-
- FileUtils.write(csvTransformFile, transformProcess.toJson());
-
- Unirest.setObjectMapper(new ObjectMapper() {
- private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
- new org.nd4j.shade.jackson.databind.ObjectMapper();
-
- public <T> T readValue(String value, Class<T> valueType) {
- try {
- return jacksonObjectMapper.readValue(value, valueType);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- public String writeValue(Object value) {
- try {
- return jacksonObjectMapper.writeValueAsString(value);
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }
- });
-
-
- }
-
- @AfterClass
- public static void after() throws Exception {
- imageTransformFile.deleteOnExit();
- csvTransformFile.deleteOnExit();
- }
-
- @Test
- public void testImageServer() throws Exception {
- serverChooser.runMain(new String[] {"--jsonPath", imageTransformFile.getAbsolutePath(), "-dp", "9060", "-dt",
- TransformDataType.IMAGE.toString()});
-
- SingleImageRecord record =
- new SingleImageRecord(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI());
- JsonNode jsonNode = Unirest.post("http://localhost:9060/transformincrementalarray")
- .header("accept", "application/json").header("Content-Type", "application/json").body(record)
- .asJson().getBody();
- Base64NDArrayBody array = Unirest.post("http://localhost:9060/transformincrementalarray")
- .header("accept", "application/json").header("Content-Type", "application/json").body(record)
- .asObject(Base64NDArrayBody.class).getBody();
-
- BatchImageRecord batch = new BatchImageRecord();
- batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/0.jpg").getFile().toURI());
- batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/1.png").getFile().toURI());
- batch.add(new ClassPathResource("datavec-spark-inference/testimages/class0/2.jpg").getFile().toURI());
-
- JsonNode jsonNodeBatch =
- Unirest.post("http://localhost:9060/transformarray").header("accept", "application/json")
- .header("Content-Type", "application/json").body(batch).asJson().getBody();
- Base64NDArrayBody batchArray = Unirest.post("http://localhost:9060/transformarray")
- .header("accept", "application/json").header("Content-Type", "application/json").body(batch)
- .asObject(Base64NDArrayBody.class).getBody();
-
- INDArray result = getNDArray(jsonNode);
- assertEquals(1, result.size(0));
-
- INDArray batchResult = getNDArray(jsonNodeBatch);
- assertEquals(3, batchResult.size(0));
-
- serverChooser.getSparkTransformServer().stop();
- }
-
- @Test
- public void testCSVServer() throws Exception {
- serverChooser.runMain(new String[] {"--jsonPath", csvTransformFile.getAbsolutePath(), "-dp", "9050", "-dt",
- TransformDataType.CSV.toString()});
-
- String[] values = new String[] {"1.0", "2.0"};
- SingleCSVRecord record = new SingleCSVRecord(values);
- JsonNode jsonNode =
- Unirest.post("http://localhost:9050/transformincremental").header("accept", "application/json")
- .header("Content-Type", "application/json").body(record).asJson().getBody();
- SingleCSVRecord singleCsvRecord = Unirest.post("http://localhost:9050/transformincremental")
- .header("accept", "application/json").header("Content-Type", "application/json").body(record)
- .asObject(SingleCSVRecord.class).getBody();
-
- BatchCSVRecord batchCSVRecord = new BatchCSVRecord();
- for (int i = 0; i < 3; i++)
- batchCSVRecord.add(singleCsvRecord);
- BatchCSVRecord batchCSVRecord1 = Unirest.post("http://localhost:9050/transform")
- .header("accept", "application/json").header("Content-Type", "application/json")
- .body(batchCSVRecord).asObject(BatchCSVRecord.class).getBody();
-
- Base64NDArrayBody array = Unirest.post("http://localhost:9050/transformincrementalarray")
- .header("accept", "application/json").header("Content-Type", "application/json").body(record)
- .asObject(Base64NDArrayBody.class).getBody();
-
- Base64NDArrayBody batchArray1 = Unirest.post("http://localhost:9050/transformarray")
- .header("accept", "application/json").header("Content-Type", "application/json")
- .body(batchCSVRecord).asObject(Base64NDArrayBody.class).getBody();
-
-
- serverChooser.getSparkTransformServer().stop();
- }
-
- public INDArray getNDArray(JsonNode node) throws IOException {
- return Nd4jBase64.fromBase64(node.getObject().getString("ndarray"));
- }
-}
diff --git a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/resources/application.conf b/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/resources/application.conf
deleted file mode 100644
index dbac92d83..000000000
--- a/datavec/datavec-spark-inference-parent/datavec-spark-inference-server/src/test/resources/application.conf
+++ /dev/null
@@ -1,6 +0,0 @@
-play.modules.enabled += com.lightbend.lagom.discovery.zookeeper.ZooKeeperServiceLocatorModule
-play.modules.enabled += io.skymind.skil.service.PredictionModule
-play.crypto.secret = as8dufasdfuasdfjkasdkfalksjfk
-play.server.pidfile.path=/tmp/RUNNING_PID
-
-play.server.http.port = 9600
diff --git a/datavec/datavec-spark-inference-parent/pom.xml b/datavec/datavec-spark-inference-parent/pom.xml
deleted file mode 100644
index abf3f3b0d..000000000
--- a/datavec/datavec-spark-inference-parent/pom.xml
+++ /dev/null
@@ -1,68 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<project xmlns="http://maven.apache.org/POM/4.0.0"
-         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
-         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
-
-    <modelVersion>4.0.0</modelVersion>
-
-    <parent>
-        <groupId>org.datavec</groupId>
-        <artifactId>datavec-parent</artifactId>
-        <version>1.0.0-SNAPSHOT</version>
-    </parent>
-
-    <artifactId>datavec-spark-inference-parent</artifactId>
-    <packaging>pom</packaging>
-
-    <name>datavec-spark-inference-parent</name>
-
-    <modules>
-        <module>datavec-spark-inference-server</module>
-        <module>datavec-spark-inference-client</module>
-        <module>datavec-spark-inference-model</module>
-    </modules>
-
-    <dependencyManagement>
-        <dependencies>
-            <dependency>
-                <groupId>org.datavec</groupId>
-                <artifactId>datavec-data-image</artifactId>
-                <version>${datavec.version}</version>
-            </dependency>
-            <dependency>
-                <groupId>com.mashape.unirest</groupId>
-                <artifactId>unirest-java</artifactId>
-                <version>${unirest.version}</version>
-            </dependency>
-        </dependencies>
-    </dependencyManagement>
-
-    <profiles>
-        <profile>
-            <id>test-nd4j-native</id>
-        </profile>
-        <profile>
-            <id>test-nd4j-cuda-11.0</id>
-        </profile>
-    </profiles>
-</project>
diff --git a/datavec/pom.xml b/datavec/pom.xml
index 4142db170..d1c46077f 100644
--- a/datavec/pom.xml
+++ b/datavec/pom.xml
@@ -45,7 +45,6 @@
datavec-data
datavec-spark
datavec-local
- datavec-spark-inference-parent
datavec-jdbc
datavec-excel
datavec-arrow
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml
deleted file mode 100644
index ee029d09f..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/pom.xml
+++ /dev/null
@@ -1,143 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<project xmlns="http://maven.apache.org/POM/4.0.0"
-         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
-         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
-
-    <modelVersion>4.0.0</modelVersion>
-
-    <parent>
-        <groupId>org.deeplearning4j</groupId>
-        <artifactId>deeplearning4j-nearestneighbors-parent</artifactId>
-        <version>1.0.0-SNAPSHOT</version>
-    </parent>
-
-    <artifactId>deeplearning4j-nearestneighbor-server</artifactId>
-    <packaging>jar</packaging>
-
-    <name>deeplearning4j-nearestneighbor-server</name>
-
-    <properties>
-        <java.compile.version>1.8</java.compile.version>
-    </properties>
-
-    <dependencies>
-        <dependency>
-            <groupId>org.deeplearning4j</groupId>
-            <artifactId>deeplearning4j-nearestneighbors-model</artifactId>
-            <version>${project.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>org.deeplearning4j</groupId>
-            <artifactId>deeplearning4j-core</artifactId>
-            <version>${project.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>io.vertx</groupId>
-            <artifactId>vertx-core</artifactId>
-            <version>${vertx.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>io.vertx</groupId>
-            <artifactId>vertx-web</artifactId>
-            <version>${vertx.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>com.mashape.unirest</groupId>
-            <artifactId>unirest-java</artifactId>
-            <version>${unirest.version}</version>
-            <scope>test</scope>
-        </dependency>
-        <dependency>
-            <groupId>org.deeplearning4j</groupId>
-            <artifactId>deeplearning4j-nearestneighbors-client</artifactId>
-            <version>${project.version}</version>
-            <scope>test</scope>
-        </dependency>
-        <dependency>
-            <groupId>com.beust</groupId>
-            <artifactId>jcommander</artifactId>
-            <version>${jcommander.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>ch.qos.logback</groupId>
-            <artifactId>logback-classic</artifactId>
-            <scope>test</scope>
-        </dependency>
-        <dependency>
-            <groupId>org.deeplearning4j</groupId>
-            <artifactId>deeplearning4j-common-tests</artifactId>
-            <version>${project.version}</version>
-            <scope>test</scope>
-        </dependency>
-    </dependencies>
-
-    <build>
-        <plugins>
-            <plugin>
-                <groupId>org.apache.maven.plugins</groupId>
-                <artifactId>maven-surefire-plugin</artifactId>
-                <configuration>
-                    <argLine>-Dfile.encoding=UTF-8 -Xmx8g</argLine>
-                    <includes>
-                        <include>*.java</include>
-                        <include>**/*.java</include>
-                    </includes>
-                </configuration>
-            </plugin>
-            <plugin>
-                <groupId>org.apache.maven.plugins</groupId>
-                <artifactId>maven-compiler-plugin</artifactId>
-                <configuration>
-                    <source>${java.compile.version}</source>
-                    <target>${java.compile.version}</target>
-                </configuration>
-            </plugin>
-        </plugins>
-    </build>
-
-    <profiles>
-        <profile>
-            <id>test-nd4j-native</id>
-            <dependencies>
-                <dependency>
-                    <groupId>org.nd4j</groupId>
-                    <artifactId>nd4j-native</artifactId>
-                    <version>${project.version}</version>
-                    <scope>test</scope>
-                </dependency>
-            </dependencies>
-        </profile>
-        <profile>
-            <id>test-nd4j-cuda-11.0</id>
-            <dependencies>
-                <dependency>
-                    <groupId>org.nd4j</groupId>
-                    <artifactId>nd4j-cuda-11.0</artifactId>
-                    <version>${project.version}</version>
-                    <scope>test</scope>
-                </dependency>
-            </dependencies>
-        </profile>
-    </profiles>
-</project>
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighbor.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighbor.java
deleted file mode 100644
index 88f3a7b46..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/main/java/org/deeplearning4j/nearestneighbor/server/NearestNeighbor.java
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.nearestneighbor.server;
-
-import lombok.AllArgsConstructor;
-import lombok.Builder;
-import org.deeplearning4j.clustering.sptree.DataPoint;
-import org.deeplearning4j.clustering.vptree.VPTree;
-import org.deeplearning4j.nearestneighbor.model.NearestNeighborRequest;
-import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult;
-import org.nd4j.linalg.api.ndarray.INDArray;
-
-import java.util.ArrayList;
-import java.util.List;
-
-@AllArgsConstructor
-@Builder
-public class NearestNeighbor {
- private NearestNeighborRequest record;
- private VPTree tree;
- private INDArray points;
-
- public List<NearestNeighborsResult> search() {
- INDArray input = points.slice(record.getInputIndex());
- List<NearestNeighborsResult> results = new ArrayList<>();
- if (input.isVector()) {
- List<DataPoint> add = new ArrayList<>();
- List<Double> distances = new ArrayList<>();
- tree.search(input, record.getK(), add, distances);
-
- if (add.size() != distances.size()) {
- throw new IllegalStateException(
- String.format("add.size == %d != %d == distances.size",
- add.size(), distances.size()));
- }
-
- for (int i = 0; i < add.size(); i++) {
- results.add(new NearestNeighborsResult(add.get(i).getIndex(), distances.get(i)));
- }
- }
-
- return results;
- }
-}
-
- //User provided invalid input -> print the usage info
- jcmdr.usage();
- if (r.ndarrayPath == null)
- log.error("Json path parameter is missing (null)");
- try {
- Thread.sleep(500);
- } catch (Exception e2) {
- }
- System.exit(1);
- }
-
- instanceArgs = r;
- try {
- Vertx vertx = Vertx.vertx();
- vertx.deployVerticle(NearestNeighborsServer.class.getName());
- } catch (Throwable t){
- log.error("Error in NearestNeighboursServer run method",t);
- }
- }
-
- @Override
- public void start() throws Exception {
- instance = this;
-
- String[] pathArr = instanceArgs.ndarrayPath.split(",");
- //INDArray[] pointsArr = new INDArray[pathArr.length];
- // First, read the shapes of the files saved earlier
- int rows = 0;
- int cols = 0;
- for (int i = 0; i < pathArr.length; i++) {
- DataBuffer shape = BinarySerde.readShapeFromDisk(new File(pathArr[i]));
-
- log.info("Loading shape {} of {}; Shape: [{} x {}]", i + 1, pathArr.length, Shape.size(shape, 0),
- Shape.size(shape, 1));
-
- if (Shape.rank(shape) != 2)
- throw new DL4JInvalidInputException("NearestNeighborsServer assumes 2D chunks");
-
- rows += Shape.size(shape, 0);
-
- if (cols == 0)
- cols = Shape.size(shape, 1);
- else if (cols != Shape.size(shape, 1))
- throw new DL4JInvalidInputException(
- "NearestNeighborsServer requires equal 2D chunks. Got columns mismatch.");
- }
-
- final List<String> labels = new ArrayList<>();
- if (instanceArgs.labelsPath != null) {
- String[] labelsPathArr = instanceArgs.labelsPath.split(",");
- for (int i = 0; i < labelsPathArr.length; i++) {
- labels.addAll(FileUtils.readLines(new File(labelsPathArr[i]), "utf-8"));
- }
- }
- if (!labels.isEmpty() && labels.size() != rows)
- throw new DL4JInvalidInputException(String.format("Number of labels must match number of rows in points matrix (expected %d, found %d)", rows, labels.size()));
-
- final INDArray points = Nd4j.createUninitialized(rows, cols);
-
- int lastPosition = 0;
- for (int i = 0; i < pathArr.length; i++) {
- log.info("Loading chunk {} of {}", i + 1, pathArr.length);
- INDArray pointsArr = BinarySerde.readFromDisk(new File(pathArr[i]));
-
- points.get(NDArrayIndex.interval(lastPosition, lastPosition + pointsArr.rows())).assign(pointsArr);
- lastPosition += pointsArr.rows();
-
- // Hint to the GC so we don't hold on to too much memory across iterations
- System.gc();
- }
-
- VPTree tree = new VPTree(points, instanceArgs.similarityFunction, instanceArgs.invert);
-
- //Set play secret key, if required
- //http://www.playframework.com/documentation/latest/ApplicationSecret
- String crypto = System.getProperty("play.crypto.secret");
- if (crypto == null || "changeme".equals(crypto) || "".equals(crypto)) {
- byte[] newCrypto = new byte[1024];
-
- new Random().nextBytes(newCrypto);
-
- String base64 = Base64.getEncoder().encodeToString(newCrypto);
- System.setProperty("play.crypto.secret", base64);
- }
-
- Router r = Router.router(vertx);
- r.route().handler(BodyHandler.create()); //NOTE: Setting this is required to receive request body content at all
- createRoutes(r, labels, tree, points);
-
- vertx.createHttpServer()
- .requestHandler(r)
- .listen(instanceArgs.port);
- }
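-
- // Hedged request sketch against the routes wired below (JSON field names
- // assumed from NearestNeighborRequest's inputIndex/k properties; <port> is
- // the configured --nearestNeighborsPort):
- //   curl -X POST http://localhost:<port>/knn \
- //        -H "Content-Type: application/json" \
- //        -d '{"inputIndex": 0, "k": 5}'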
-
- private void createRoutes(Router r, List<String> labels, VPTree tree, INDArray points) {
-
- r.post("/knn").handler(rc -> {
- try {
- String json = rc.getBodyAsJson().encode();
- NearestNeighborRequest record = JsonMappers.getMapper().readValue(json, NearestNeighborRequest.class);
-
- if (record == null) {
- rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
- .putHeader("content-type", "application/json")
- .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
- return;
- }
-
- NearestNeighbor nearestNeighbor =
- NearestNeighbor.builder().points(points).record(record).tree(tree).build();
-
- NearestNeighborsResults results = NearestNeighborsResults.builder().results(nearestNeighbor.search()).build();
-
- //Null check must come before use; the success response returns OK, not BAD_REQUEST
- rc.response().setStatusCode(HttpResponseStatus.OK.code())
- .putHeader("content-type", "application/json")
- .end(JsonMappers.getMapper().writeValueAsString(results));
- return;
- } catch (Throwable e) {
- log.error("Error in POST /knn",e);
- rc.response().setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
- .end("Error parsing request - " + e.getMessage());
- return;
- }
- });
-
- r.post("/knnnew").handler(rc -> {
- try {
- String json = rc.getBodyAsJson().encode();
- Base64NDArrayBody record = JsonMappers.getMapper().readValue(json, Base64NDArrayBody.class);
- if (record == null) {
- rc.response().setStatusCode(HttpResponseStatus.BAD_REQUEST.code())
- .putHeader("content-type", "application/json")
- .end(JsonMappers.getMapper().writeValueAsString(Collections.singletonMap("status", "invalid json passed.")));
- return;
- }
-
- INDArray arr = Nd4jBase64.fromBase64(record.getNdarray());
- List<DataPoint> results;
- List<Double> distances;
-
- if (record.isForceFillK()) {
- VPTreeFillSearch vpTreeFillSearch = new VPTreeFillSearch(tree, record.getK(), arr);
- vpTreeFillSearch.search();
- results = vpTreeFillSearch.getResults();
- distances = vpTreeFillSearch.getDistances();
- } else {
- results = new ArrayList<>();
- distances = new ArrayList<>();
- tree.search(arr, record.getK(), results, distances);
- }
-
- if (results.size() != distances.size()) {
- rc.response()
- .setStatusCode(HttpResponseStatus.INTERNAL_SERVER_ERROR.code())
- .end(String.format("results.size == %d != %d == distances.size", results.size(), distances.size()));
- return;
- }
-
- List<NearestNeighborsResult> nnResult = new ArrayList<>();
- for (int i = 0; i < results.size(); i++) {
- nnResult.add(new NearestNeighborsResult(results.get(i).getIndex(), distances.get(i)));
- }
-
- List<NearestNeighborsResult> results = nearestNeighbor.search();
- assertEquals(1, results.get(0).getIndex());
- assertEquals(2, results.size());
-
- assertEquals(1.0, results.get(0).getDistance(), 1e-4);
- assertEquals(4.0, results.get(1).getDistance(), 1e-4);
- }
-
- @Test
- public void testNearestNeighborInverted() {
- double[][] data = new double[][] {{1, 2, 3, 4}, {1, 2, 3, 5}, {3, 4, 5, 6}};
- INDArray arr = Nd4j.create(data);
-
- VPTree vpTree = new VPTree(arr, true);
- NearestNeighborRequest request = new NearestNeighborRequest();
- request.setK(2);
- request.setInputIndex(0);
- NearestNeighbor nearestNeighbor = NearestNeighbor.builder().tree(vpTree).points(arr).record(request).build();
- List<NearestNeighborsResult> results = nearestNeighbor.search();
- assertEquals(2, results.get(0).getIndex());
- assertEquals(2, results.size());
-
- assertEquals(-4.0, results.get(0).getDistance(), 1e-4);
- assertEquals(-1.0, results.get(1).getDistance(), 1e-4);
- }
-
- @Test
- public void vpTreeTest() throws Exception {
- INDArray matrix = Nd4j.rand(new int[] {400,10});
- INDArray rowVector = matrix.getRow(70);
- INDArray resultArr = Nd4j.zeros(400,1);
- Executor executor = Executors.newSingleThreadExecutor();
- VPTree vpTree = new VPTree(matrix);
- System.out.println("Ran!");
- }
-
-
-
- public static int getAvailablePort() {
- try {
- ServerSocket socket = new ServerSocket(0);
- try {
- return socket.getLocalPort();
- } finally {
- socket.close();
- }
- } catch (IOException e) {
- throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e);
- }
- }
-
- @Test
- public void testServer() throws Exception {
- int localPort = getAvailablePort();
- Nd4j.getRandom().setSeed(7);
- INDArray rand = Nd4j.randn(10, 5);
- File writeToTmp = testDir.newFile();
- writeToTmp.deleteOnExit();
- BinarySerde.writeArrayToDisk(rand, writeToTmp);
- NearestNeighborsServer.runMain("--ndarrayPath", writeToTmp.getAbsolutePath(), "--nearestNeighborsPort",
- String.valueOf(localPort));
-
- Thread.sleep(3000);
-
- NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:" + localPort);
- NearestNeighborsResults result = client.knnNew(5, rand.getRow(0));
- assertEquals(5, result.getResults().size());
- NearestNeighborsServer.getInstance().stop();
- }
-
-
-
- @Test
- public void testFullSearch() throws Exception {
- int numRows = 1000;
- int numCols = 100;
- int numNeighbors = 42;
- INDArray points = Nd4j.rand(numRows, numCols);
- VPTree tree = new VPTree(points);
- INDArray query = Nd4j.rand(new int[] {1, numCols});
- VPTreeFillSearch fillSearch = new VPTreeFillSearch(tree, numNeighbors, query);
- fillSearch.search();
- List<DataPoint> results = fillSearch.getResults();
- List<Double> distances = fillSearch.getDistances();
- assertEquals(numNeighbors, distances.size());
- assertEquals(numNeighbors, results.size());
- }
-
- @Test
- public void testDistances() {
-
- INDArray indArray = Nd4j.create(new float[][]{{3, 4}, {1, 2}, {5, 6}});
- INDArray record = Nd4j.create(new float[][]{{7, 6}});
- VPTree vpTree = new VPTree(indArray, "euclidean", false);
- VPTreeFillSearch vpTreeFillSearch = new VPTreeFillSearch(vpTree, 3, record);
- vpTreeFillSearch.search();
- //System.out.println(vpTreeFillSearch.getResults());
- System.out.println(vpTreeFillSearch.getDistances());
- }
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml
deleted file mode 100644
index 7e0af0fa1..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbor-server/src/test/resources/logback.xml
+++ /dev/null
@@ -1,46 +0,0 @@
-<configuration>
-
-    <appender name="FILE" class="ch.qos.logback.core.FileAppender">
-        <file>logs/application.log</file>
-        <encoder>
-            <pattern>%logger{15} - %message%n%xException{5}</pattern>
-        </encoder>
-    </appender>
-
-    <appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
-        <encoder>
-            <pattern>%logger{15} - %message%n%xException{5}</pattern>
-        </encoder>
-    </appender>
-
-    <root>
-        <appender-ref ref="STDOUT" />
-        <appender-ref ref="FILE" />
-    </root>
-</configuration>
\ No newline at end of file
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml
deleted file mode 100644
index 55d7b83f9..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/pom.xml
+++ /dev/null
@@ -1,60 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<project xmlns="http://maven.apache.org/POM/4.0.0"
-         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
-         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
-
-    <modelVersion>4.0.0</modelVersion>
-
-    <parent>
-        <groupId>org.deeplearning4j</groupId>
-        <artifactId>deeplearning4j-nearestneighbors-parent</artifactId>
-        <version>1.0.0-SNAPSHOT</version>
-    </parent>
-
-    <artifactId>deeplearning4j-nearestneighbors-client</artifactId>
-    <packaging>jar</packaging>
-
-    <name>deeplearning4j-nearestneighbors-client</name>
-
-    <dependencies>
-        <dependency>
-            <groupId>com.mashape.unirest</groupId>
-            <artifactId>unirest-java</artifactId>
-            <version>${unirest.version}</version>
-        </dependency>
-        <dependency>
-            <groupId>org.deeplearning4j</groupId>
-            <artifactId>deeplearning4j-nearestneighbors-model</artifactId>
-            <version>${project.version}</version>
-        </dependency>
-    </dependencies>
-
-    <profiles>
-        <profile>
-            <id>test-nd4j-native</id>
-        </profile>
-        <profile>
-            <id>test-nd4j-cuda-11.0</id>
-        </profile>
-    </profiles>
-</project>
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java
deleted file mode 100644
index 570e75bf9..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-client/src/main/java/org/deeplearning4j/nearestneighbor/client/NearestNeighborsClient.java
+++ /dev/null
@@ -1,137 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.nearestneighbor.client;
-
-import com.mashape.unirest.http.ObjectMapper;
-import com.mashape.unirest.http.Unirest;
-import com.mashape.unirest.request.HttpRequest;
-import lombok.AllArgsConstructor;
-import lombok.Getter;
-import lombok.Setter;
-import lombok.val;
-import org.deeplearning4j.nearestneighbor.model.*;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.serde.base64.Nd4jBase64;
-import org.nd4j.shade.jackson.core.JsonProcessingException;
-
-import java.io.IOException;
-
-@AllArgsConstructor
-public class NearestNeighborsClient {
-
- private String url;
- @Setter
- @Getter
- protected String authToken;
-
- public NearestNeighborsClient(String url){
- this(url, null);
- }
-
- static {
- // Configure the Unirest object mapper only once
-
- Unirest.setObjectMapper(new ObjectMapper() {
- private org.nd4j.shade.jackson.databind.ObjectMapper jacksonObjectMapper =
- new org.nd4j.shade.jackson.databind.ObjectMapper();
-
- public <T> T readValue(String value, Class<T> valueType) {
- try {
- return jacksonObjectMapper.readValue(value, valueType);
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- public String writeValue(Object value) {
- try {
- return jacksonObjectMapper.writeValueAsString(value);
- } catch (JsonProcessingException e) {
- throw new RuntimeException(e);
- }
- }
- });
- }
-
-
- /**
- * Runs knn on the given index with the given k. Note that this is for a data
- * point already within the existing dataset, not for new data.
- * @param index the index of the EXISTING ndarray row to run the search on
- * @param k the number of results to retrieve
- * @return the nearest neighbors results
- * @throws Exception if the request fails
- */
- public NearestNeighborsResults knn(int index, int k) throws Exception {
- NearestNeighborRequest request = new NearestNeighborRequest();
- request.setInputIndex(index);
- request.setK(k);
- val req = Unirest.post(url + "/knn");
- req.header("accept", "application/json")
- .header("Content-Type", "application/json").body(request);
- addAuthHeader(req);
-
- NearestNeighborsResults ret = req.asObject(NearestNeighborsResults.class).getBody();
- return ret;
- }
-
- /**
- * Runs a k nearest neighbors search on a NEW data point.
- * @param k the number of results to retrieve
- * @param arr the array to run the search on; note that this must be a row vector
- * @return the nearest neighbors results
- * @throws Exception if the request fails
- */
- public NearestNeighborsResults knnNew(int k, INDArray arr) throws Exception {
- Base64NDArrayBody base64NDArrayBody =
- Base64NDArrayBody.builder().k(k).ndarray(Nd4jBase64.base64String(arr)).build();
-
- val req = Unirest.post(url + "/knnnew");
- req.header("accept", "application/json")
- .header("Content-Type", "application/json").body(base64NDArrayBody);
- addAuthHeader(req);
-
- NearestNeighborsResults ret = req.asObject(NearestNeighborsResults.class).getBody();
-
- return ret;
- }
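-
- // Hedged usage sketch (server URL assumed):
- //   NearestNeighborsClient client = new NearestNeighborsClient("http://localhost:9000");
- //   NearestNeighborsResults forExisting = client.knn(0, 5);               // index already in the dataset
- //   NearestNeighborsResults forNew = client.knnNew(5, Nd4j.rand(1, 100)); // new row vector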
-
-
- /**
- * Add the specified authentication header to the specified HttpRequest
- *
- * @param request HTTP Request to add the authentication header to
- */
- protected HttpRequest addAuthHeader(HttpRequest request) {
- if (authToken != null) {
- request.header("authorization", "Bearer " + authToken);
- }
-
- return request;
- }
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml
deleted file mode 100644
index 09a72628e..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/pom.xml
+++ /dev/null
@@ -1,61 +0,0 @@
-<?xml version="1.0" encoding="UTF-8"?>
-<project xmlns="http://maven.apache.org/POM/4.0.0"
-         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
-         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
-
-    <modelVersion>4.0.0</modelVersion>
-
-    <parent>
-        <groupId>org.deeplearning4j</groupId>
-        <artifactId>deeplearning4j-nearestneighbors-parent</artifactId>
-        <version>1.0.0-SNAPSHOT</version>
-    </parent>
-
-    <artifactId>deeplearning4j-nearestneighbors-model</artifactId>
-    <packaging>jar</packaging>
-
-    <name>deeplearning4j-nearestneighbors-model</name>
-
-    <dependencies>
-        <dependency>
-            <groupId>org.projectlombok</groupId>
-            <artifactId>lombok</artifactId>
-            <version>${lombok.version}</version>
-            <scope>provided</scope>
-        </dependency>
-        <dependency>
-            <groupId>org.nd4j</groupId>
-            <artifactId>nd4j-api</artifactId>
-            <version>${nd4j.version}</version>
-        </dependency>
-    </dependencies>
-
-    <profiles>
-        <profile>
-            <id>test-nd4j-native</id>
-        </profile>
-        <profile>
-            <id>test-nd4j-cuda-11.0</id>
-        </profile>
-    </profiles>
-</project>
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/Base64NDArrayBody.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/Base64NDArrayBody.java
deleted file mode 100644
index c68f48ebe..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/Base64NDArrayBody.java
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.nearestneighbor.model;
-
-import lombok.AllArgsConstructor;
-import lombok.Builder;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-
-import java.io.Serializable;
-
-@Data
-@AllArgsConstructor
-@NoArgsConstructor
-@Builder
-public class Base64NDArrayBody implements Serializable {
- private String ndarray;
- private int k;
- private boolean forceFillK;
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/BatchRecord.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/BatchRecord.java
deleted file mode 100644
index f2a9475a1..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/BatchRecord.java
+++ /dev/null
@@ -1,65 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.nearestneighbor.model;
-
-import lombok.AllArgsConstructor;
-import lombok.Builder;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-import org.nd4j.linalg.dataset.DataSet;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
-
-@Data
-@AllArgsConstructor
-@Builder
-@NoArgsConstructor
-public class BatchRecord implements Serializable {
- private List<CSVRecord> records;
-
- /**
- * Add a record to this batch, initializing the backing list if necessary
- * @param record the record to add
- */
- public void add(CSVRecord record) {
- if (records == null)
- records = new ArrayList<>();
- records.add(record);
- }
-
-
- /**
- * Return a batch record based on a dataset
- * @param dataSet the dataset to get the batch record for
- * @return the batch record
- */
- public static BatchRecord fromDataSet(DataSet dataSet) {
- BatchRecord batchRecord = new BatchRecord();
- for (int i = 0; i < dataSet.numExamples(); i++) {
- batchRecord.add(CSVRecord.fromRow(dataSet.get(i)));
- }
-
- return batchRecord;
- }
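-
- // Hedged usage sketch: one CSVRecord per example in the dataset
- //   DataSet ds = new DataSet(features, labels);      // n examples
- //   BatchRecord batch = BatchRecord.fromDataSet(ds);
- //   batch.getRecords().size();                       // == ds.numExamples()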
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/CSVRecord.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/CSVRecord.java
deleted file mode 100644
index ef642bf0d..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/CSVRecord.java
+++ /dev/null
@@ -1,85 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.nearestneighbor.model;
-
-import lombok.AllArgsConstructor;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-import org.nd4j.linalg.dataset.DataSet;
-
-import java.io.Serializable;
-
-@Data
-@AllArgsConstructor
-@NoArgsConstructor
-public class CSVRecord implements Serializable {
- private String[] values;
-
- /**
-  * Instantiate a CSV record from a dataset row.
-  * For classification (a one-hot label vector), the index of the
-  * maximum label value is appended to the end of the record;
-  * for regression, all label values are appended.
-  * @param row the input row (features and labels)
-  * @return the record for this {@link DataSet}
-  */
- public static CSVRecord fromRow(DataSet row) {
- if (!row.getFeatures().isVector() && !row.getFeatures().isScalar())
- throw new IllegalArgumentException("Passed in dataset must represent a scalar or vector");
- if (!row.getLabels().isVector() && !row.getLabels().isScalar())
- throw new IllegalArgumentException("Passed in dataset labels must be a scalar or vector");
- //classification
- CSVRecord record;
- int idx = 0;
- if (row.getLabels().sumNumber().doubleValue() == 1.0) {
- String[] values = new String[row.getFeatures().columns() + 1];
- for (int i = 0; i < row.getFeatures().length(); i++) {
- values[idx++] = String.valueOf(row.getFeatures().getDouble(i));
- }
- int maxIdx = 0;
- for (int i = 0; i < row.getLabels().length(); i++) {
- if (row.getLabels().getDouble(maxIdx) < row.getLabels().getDouble(i)) {
- maxIdx = i;
- }
- }
-
- values[idx++] = String.valueOf(maxIdx);
- record = new CSVRecord(values);
- }
- //regression (any number of values)
- else {
- String[] values = new String[row.getFeatures().columns() + row.getLabels().columns()];
- for (int i = 0; i < row.getFeatures().length(); i++) {
- values[idx++] = String.valueOf(row.getFeatures().getDouble(i));
- }
- for (int i = 0; i < row.getLabels().length(); i++) {
- values[idx++] = String.valueOf(row.getLabels().getDouble(i));
- }
-
-
- record = new CSVRecord(values);
-
- }
- return record;
- }
-
-}
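To make the two branches above concrete: labels summing to 1.0 take the classification branch and append the argmax index; any other labels take the regression branch and append every label value. A small sketch:

    import org.deeplearning4j.nearestneighbor.model.CSVRecord;
    import org.nd4j.linalg.dataset.DataSet;
    import org.nd4j.linalg.factory.Nd4j;

    public class CSVRecordExample {
        public static void main(String[] args) {
            DataSet classification = new DataSet(
                    Nd4j.create(new double[] {0.5, 0.25}, new long[] {1, 2}),
                    Nd4j.create(new double[] {0, 1}, new long[] {1, 2}));
            // Prints: 0.5,0.25,1  (features, then the argmax label index)
            System.out.println(String.join(",", CSVRecord.fromRow(classification).getValues()));

            DataSet regression = new DataSet(
                    Nd4j.create(new double[] {0.5, 0.25}, new long[] {1, 2}),
                    Nd4j.create(new double[] {2.5, 3.5}, new long[] {1, 2}));
            // Prints: 0.5,0.25,2.5,3.5  (features, then every label value)
            System.out.println(String.join(",", CSVRecord.fromRow(regression).getValues()));
        }
    }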
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborRequest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborRequest.java
deleted file mode 100644
index 5044c6b35..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborRequest.java
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.nearestneighbor.model;
-
-import lombok.Data;
-
-import java.io.Serializable;
-
-@Data
-public class NearestNeighborRequest implements Serializable {
- private int k;
- private int inputIndex;
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborsResult.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborsResult.java
deleted file mode 100644
index 768b0dfc9..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborsResult.java
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.nearestneighbor.model;
-
-import lombok.AllArgsConstructor;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-@Data
-@AllArgsConstructor
-@NoArgsConstructor
-public class NearestNeighborsResult {
- public NearestNeighborsResult(int index, double distance) {
- this(index, distance, null);
- }
-
- private int index;
- private double distance;
- private String label;
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborsResults.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborsResults.java
deleted file mode 100644
index d95c68fb6..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/deeplearning4j-nearestneighbors-model/src/main/java/org/deeplearning4j/nearestneighbor/model/NearestNeighborsResults.java
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.nearestneighbor.model;
-
-import lombok.AllArgsConstructor;
-import lombok.Builder;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-
-import java.io.Serializable;
-import java.util.List;
-
-@Data
-@Builder
-@NoArgsConstructor
-@AllArgsConstructor
-public class NearestNeighborsResults implements Serializable {
- private List<NearestNeighborsResult> results;
-
-}
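Taken together, these DTOs form the wire format of the nearest-neighbors server. A sketch of assembling a response:

    import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResult;
    import org.deeplearning4j.nearestneighbor.model.NearestNeighborsResults;

    import java.util.Arrays;

    public class ResultsExample {
        public static void main(String[] args) {
            NearestNeighborsResults results = NearestNeighborsResults.builder()
                    .results(Arrays.asList(
                            new NearestNeighborsResult(3, 0.12),         // unlabeled neighbor
                            new NearestNeighborsResult(7, 0.48, "cat"))) // labeled neighbor
                    .build();
            for (NearestNeighborsResult r : results.getResults())
                System.out.println(r.getIndex() + " -> " + r.getDistance());
        }
    }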
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml
deleted file mode 100644
index 5df85229d..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/pom.xml
+++ /dev/null
@@ -1,103 +0,0 @@
- <?xml version="1.0" encoding="UTF-8"?>
- <project xmlns="http://maven.apache.org/POM/4.0.0"
-          xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
-          xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
-
-     <modelVersion>4.0.0</modelVersion>
-
-     <parent>
-         <groupId>org.deeplearning4j</groupId>
-         <artifactId>deeplearning4j-nearestneighbors-parent</artifactId>
-         <version>1.0.0-SNAPSHOT</version>
-     </parent>
-
-     <artifactId>nearestneighbor-core</artifactId>
-     <packaging>jar</packaging>
-
-     <name>nearestneighbor-core</name>
-
-     <dependencies>
-         <dependency>
-             <groupId>org.nd4j</groupId>
-             <artifactId>nd4j-api</artifactId>
-             <version>${nd4j.version}</version>
-         </dependency>
-         <dependency>
-             <groupId>junit</groupId>
-             <artifactId>junit</artifactId>
-         </dependency>
-         <dependency>
-             <groupId>ch.qos.logback</groupId>
-             <artifactId>logback-classic</artifactId>
-             <scope>test</scope>
-         </dependency>
-         <dependency>
-             <groupId>org.deeplearning4j</groupId>
-             <artifactId>deeplearning4j-nn</artifactId>
-             <version>${project.version}</version>
-         </dependency>
-         <dependency>
-             <groupId>org.deeplearning4j</groupId>
-             <artifactId>deeplearning4j-datasets</artifactId>
-             <version>${project.version}</version>
-             <scope>test</scope>
-         </dependency>
-         <dependency>
-             <groupId>joda-time</groupId>
-             <artifactId>joda-time</artifactId>
-             <version>2.10.3</version>
-             <scope>test</scope>
-         </dependency>
-         <dependency>
-             <groupId>org.deeplearning4j</groupId>
-             <artifactId>deeplearning4j-common-tests</artifactId>
-             <version>${project.version}</version>
-             <scope>test</scope>
-         </dependency>
-     </dependencies>
-
-     <profiles>
-         <profile>
-             <id>test-nd4j-native</id>
-             <dependencies>
-                 <dependency>
-                     <groupId>org.nd4j</groupId>
-                     <artifactId>nd4j-native</artifactId>
-                     <version>${project.version}</version>
-                     <scope>test</scope>
-                 </dependency>
-             </dependencies>
-         </profile>
-         <profile>
-             <id>test-nd4j-cuda-11.0</id>
-             <dependencies>
-                 <dependency>
-                     <groupId>org.nd4j</groupId>
-                     <artifactId>nd4j-cuda-11.0</artifactId>
-                     <version>${project.version}</version>
-                     <scope>test</scope>
-                 </dependency>
-             </dependencies>
-         </profile>
-     </profiles>
- </project>
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/BaseClusteringAlgorithm.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/BaseClusteringAlgorithm.java
deleted file mode 100755
index e7e467ad3..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/BaseClusteringAlgorithm.java
+++ /dev/null
@@ -1,218 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.algorithm;
-
-import lombok.AccessLevel;
-import lombok.NoArgsConstructor;
-import lombok.extern.slf4j.Slf4j;
-import lombok.val;
-import org.apache.commons.lang3.ArrayUtils;
-import org.deeplearning4j.clustering.cluster.Cluster;
-import org.deeplearning4j.clustering.cluster.ClusterSet;
-import org.deeplearning4j.clustering.cluster.ClusterUtils;
-import org.deeplearning4j.clustering.cluster.Point;
-import org.deeplearning4j.clustering.info.ClusterSetInfo;
-import org.deeplearning4j.clustering.iteration.IterationHistory;
-import org.deeplearning4j.clustering.iteration.IterationInfo;
-import org.deeplearning4j.clustering.strategy.ClusteringStrategy;
-import org.deeplearning4j.clustering.strategy.ClusteringStrategyType;
-import org.deeplearning4j.clustering.strategy.OptimisationStrategy;
-import org.deeplearning4j.clustering.util.MultiThreadUtils;
-import org.nd4j.common.base.Preconditions;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.ExecutorService;
-
-@Slf4j
-@NoArgsConstructor(access = AccessLevel.PROTECTED)
-public class BaseClusteringAlgorithm implements ClusteringAlgorithm, Serializable {
-
- private static final long serialVersionUID = 338231277453149972L;
-
- private ClusteringStrategy clusteringStrategy;
- private IterationHistory iterationHistory;
- private int currentIteration = 0;
- private ClusterSet clusterSet;
- private List<Point> initialPoints;
- private transient ExecutorService exec;
- private boolean useKmeansPlusPlus;
-
-
- protected BaseClusteringAlgorithm(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) {
- this.clusteringStrategy = clusteringStrategy;
- this.exec = MultiThreadUtils.newExecutorService();
- this.useKmeansPlusPlus = useKmeansPlusPlus;
- }
-
- /**
- * Set up a clustering algorithm for the given strategy
- * @param clusteringStrategy the clustering strategy to apply
- * @param useKmeansPlusPlus whether to use k-means++ style center initialization
- * @return a new clustering algorithm instance
- */
- public static BaseClusteringAlgorithm setup(ClusteringStrategy clusteringStrategy, boolean useKmeansPlusPlus) {
- return new BaseClusteringAlgorithm(clusteringStrategy, useKmeansPlusPlus);
- }
-
- /**
- * Apply the clustering algorithm to the given points
- * @param points the points to cluster
- * @return the resulting cluster set
- */
- public ClusterSet applyTo(List<Point> points) {
- resetState(points);
- initClusters(useKmeansPlusPlus);
- iterations();
- return clusterSet;
- }
-
- private void resetState(List<Point> points) {
- this.iterationHistory = new IterationHistory();
- this.currentIteration = 0;
- this.clusterSet = null;
- this.initialPoints = points;
- }
-
- /** Run clustering iterations until a
- * termination condition is hit.
- * This is done by first classifying all points,
- * and then updating cluster centers based on
- * those classified points
- */
- private void iterations() {
- int iterationCount = 0;
- while ((clusteringStrategy.getTerminationCondition() != null
- && !clusteringStrategy.getTerminationCondition().isSatisfied(iterationHistory))
- || iterationHistory.getMostRecentIterationInfo().isStrategyApplied()) {
- currentIteration++;
- removePoints();
- classifyPoints();
- applyClusteringStrategy();
- log.trace("Completed clustering iteration {}", ++iterationCount);
- }
- }
-
- protected void classifyPoints() {
- //Classify points. This also adds each point to the ClusterSet
- ClusterSetInfo clusterSetInfo = ClusterUtils.classifyPoints(clusterSet, initialPoints, exec);
- //Update the cluster centers, based on the points within each cluster
- ClusterUtils.refreshClustersCenters(clusterSet, clusterSetInfo, exec);
- iterationHistory.getIterationsInfos().put(currentIteration,
- new IterationInfo(currentIteration, clusterSetInfo));
- }
-
- /**
- * Initialize the
- * cluster centers at random
- */
- protected void initClusters(boolean kMeansPlusPlus) {
- log.info("Generating initial clusters");
- List<Point> points = new ArrayList<>(initialPoints);
-
- //Initialize the ClusterSet with a single cluster center (based on position of one of the points chosen randomly)
- val random = Nd4j.getRandom();
- Distance distanceFn = clusteringStrategy.getDistanceFunction();
- int initialClusterCount = clusteringStrategy.getInitialClusterCount();
- clusterSet = new ClusterSet(distanceFn,
- clusteringStrategy.inverseDistanceCalculation(), new long[]{initialClusterCount, points.get(0).getArray().length()});
- clusterSet.addNewClusterWithCenter(points.remove(random.nextInt(points.size())));
-
-
- //dxs: distances between
- // each point and nearest cluster to that point
- INDArray dxs = Nd4j.create(points.size());
- dxs.addi(clusteringStrategy.inverseDistanceCalculation() ? -Double.MAX_VALUE : Double.MAX_VALUE);
-
- //Generate the initial cluster centers, by randomly selecting a point between 0 and max distance
- //Thus, we are more likely to select (as a new cluster center) a point that is far from an existing cluster
- while (clusterSet.getClusterCount() < initialClusterCount && !points.isEmpty()) {
- dxs = ClusterUtils.computeSquareDistancesFromNearestCluster(clusterSet, points, dxs, exec);
- double summed = Nd4j.sum(dxs).getDouble(0);
- double r = kMeansPlusPlus ? random.nextDouble() * summed:
- random.nextFloat() * dxs.maxNumber().doubleValue();
-
- for (int i = 0; i < dxs.length(); i++) {
- double distance = dxs.getDouble(i);
- Preconditions.checkState(distance >= 0, "Encountered negative distance: distance function is not valid? Distance " +
- "function must return values >= 0, got distance %s for function %s", distance, distanceFn);
- if (dxs.getDouble(i) >= r) {
- clusterSet.addNewClusterWithCenter(points.remove(i));
- dxs = Nd4j.create(ArrayUtils.remove(dxs.data().asDouble(), i));
- break;
- }
- }
- }
-
- ClusterSetInfo initialClusterSetInfo = ClusterUtils.computeClusterSetInfo(clusterSet);
- iterationHistory.getIterationsInfos().put(currentIteration,
- new IterationInfo(currentIteration, initialClusterSetInfo));
- }
-
-
- protected void applyClusteringStrategy() {
- if (!isStrategyApplicableNow())
- return;
-
- ClusterSetInfo clusterSetInfo = iterationHistory.getMostRecentClusterSetInfo();
- if (!clusteringStrategy.isAllowEmptyClusters()) {
- int removedCount = removeEmptyClusters(clusterSetInfo);
- if (removedCount > 0) {
- iterationHistory.getMostRecentIterationInfo().setStrategyApplied(true);
-
- if (clusteringStrategy.isStrategyOfType(ClusteringStrategyType.FIXED_CLUSTER_COUNT)
- && clusterSet.getClusterCount() < clusteringStrategy.getInitialClusterCount()) {
- int splitCount = ClusterUtils.splitMostSpreadOutClusters(clusterSet, clusterSetInfo,
- clusteringStrategy.getInitialClusterCount() - clusterSet.getClusterCount(), exec);
- if (splitCount > 0)
- iterationHistory.getMostRecentIterationInfo().setStrategyApplied(true);
- }
- }
- }
- if (clusteringStrategy.isStrategyOfType(ClusteringStrategyType.OPTIMIZATION))
- optimize();
- }
-
- protected void optimize() {
- ClusterSetInfo clusterSetInfo = iterationHistory.getMostRecentClusterSetInfo();
- OptimisationStrategy optimization = (OptimisationStrategy) clusteringStrategy;
- boolean applied = ClusterUtils.applyOptimization(optimization, clusterSet, clusterSetInfo, exec);
- iterationHistory.getMostRecentIterationInfo().setStrategyApplied(applied);
- }
-
- private boolean isStrategyApplicableNow() {
- return clusteringStrategy.isOptimizationDefined() && iterationHistory.getIterationCount() != 0
- && clusteringStrategy.isOptimizationApplicableNow(iterationHistory);
- }
-
- protected int removeEmptyClusters(ClusterSetInfo clusterSetInfo) {
- List<Cluster> removedClusters = clusterSet.removeEmptyClusters();
- clusterSetInfo.removeClusterInfos(removedClusters);
- return removedClusters.size();
- }
-
- protected void removePoints() {
- clusterSet.removePoints();
- }
-
-}
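A minimal sketch of driving this algorithm end to end. It assumes the KMeansClustering subclass from the same module and its setup(clusterCount, maxIterationCount, distance, inverse) factory; if those differ, BaseClusteringAlgorithm.setup with an explicit ClusteringStrategy is the lower-level route.

    import org.deeplearning4j.clustering.algorithm.Distance;
    import org.deeplearning4j.clustering.cluster.ClusterSet;
    import org.deeplearning4j.clustering.cluster.Point;
    import org.deeplearning4j.clustering.kmeans.KMeansClustering;
    import org.nd4j.linalg.factory.Nd4j;

    import java.util.List;

    public class ClusteringExample {
        public static void main(String[] args) {
            // 100 random 5-dimensional points
            List<Point> points = Point.toPoints(Nd4j.rand(100, 5));
            // 3 clusters, at most 10 iterations, Euclidean distance, no inverse distance
            KMeansClustering kMeans = KMeansClustering.setup(3, 10, Distance.EUCLIDEAN, false);
            ClusterSet clusterSet = kMeans.applyTo(points);
            System.out.println(clusterSet.getClusterCount()); // 3
        }
    }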
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/ClusteringAlgorithm.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/ClusteringAlgorithm.java
deleted file mode 100644
index 02ac17f39..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/ClusteringAlgorithm.java
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.algorithm;
-
-import org.deeplearning4j.clustering.cluster.ClusterSet;
-import org.deeplearning4j.clustering.cluster.Point;
-
-import java.util.List;
-
-public interface ClusteringAlgorithm {
-
- /**
- * Apply the clustering
- * algorithm to the given points
- * @param points the points to cluster
- * @return the resulting cluster set
- */
- ClusterSet applyTo(List<Point> points);
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/Distance.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/Distance.java
deleted file mode 100644
index 657df3dfa..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/algorithm/Distance.java
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.algorithm;
-
-public enum Distance {
- EUCLIDEAN("euclidean"),
- COSINE_DISTANCE("cosinedistance"),
- COSINE_SIMILARITY("cosinesimilarity"),
- MANHATTAN("manhattan"),
- DOT("dot"),
- JACCARD("jaccard"),
- HAMMING("hamming");
-
- private String functionName;
- private Distance(String name) {
- functionName = name;
- }
-
- @Override
- public String toString() {
- return functionName;
- }
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/CentersHolder.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/CentersHolder.java
deleted file mode 100644
index 8a39d8bc3..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/CentersHolder.java
+++ /dev/null
@@ -1,105 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.cluster;
-
-import org.deeplearning4j.clustering.algorithm.Distance;
-import org.nd4j.linalg.api.buffer.DataType;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.ReduceOp;
-import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.common.primitives.Pair;
-
-public class CentersHolder {
- private INDArray centers;
- private long index = 0;
-
- protected transient ReduceOp op;
- protected ArgMin imin;
- protected transient INDArray distances;
- protected transient INDArray argMin;
-
- private long rows, cols;
-
- public CentersHolder(long rows, long cols) {
- this.rows = rows;
- this.cols = cols;
- }
-
- public INDArray getCenters() {
- return this.centers;
- }
-
- public synchronized void addCenter(INDArray pointView) {
- if (centers == null)
- this.centers = Nd4j.create(pointView.dataType(), new long[] {rows, cols});
-
- centers.putRow(index++, pointView);
- }
-
- public synchronized Pair<Double, Long> getCenterByMinDistance(Point point, Distance distanceFunction) {
- if (distances == null)
- distances = Nd4j.create(centers.dataType(), centers.rows());
-
- if (argMin == null)
- argMin = Nd4j.createUninitialized(DataType.LONG, new long[0]);
-
- if (op == null) {
- op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1);
- imin = new ArgMin(distances, argMin);
- op.setZ(distances);
- }
-
- op.setY(point.getArray());
-
- Nd4j.getExecutioner().exec(op);
- Nd4j.getExecutioner().exec(imin);
-
- Pair<Double, Long> result = new Pair<>();
- result.setFirst(distances.getDouble(argMin.getLong(0)));
- result.setSecond(argMin.getLong(0));
- return result;
- }
-
- public synchronized INDArray getMinDistances(Point point, Distance distanceFunction) {
- if (distances == null)
- distances = Nd4j.create(centers.dataType(), centers.rows());
-
- if (argMin == null)
- argMin = Nd4j.createUninitialized(DataType.LONG, new long[0]);
-
- if (op == null) {
- op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1);
- imin = new ArgMin(distances, argMin);
- op.setZ(distances);
- }
-
- op.setY(point.getArray());
-
- Nd4j.getExecutioner().exec(op);
- Nd4j.getExecutioner().exec(imin);
-
- return distances;
- }
-
-
-}
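In short, CentersHolder keeps all centers in one matrix so that a single reduce op yields the distance from a point to every center. A small sketch using only the classes in this diff:

    import org.deeplearning4j.clustering.algorithm.Distance;
    import org.deeplearning4j.clustering.cluster.CentersHolder;
    import org.deeplearning4j.clustering.cluster.Point;
    import org.nd4j.common.primitives.Pair;
    import org.nd4j.linalg.factory.Nd4j;

    public class CentersHolderExample {
        public static void main(String[] args) {
            CentersHolder holder = new CentersHolder(2, 3); // room for 2 centers of dimension 3
            holder.addCenter(Nd4j.create(new double[] {0, 0, 0}));
            holder.addCenter(Nd4j.create(new double[] {10, 10, 10}));

            Point query = new Point(Nd4j.create(new double[] {1, 1, 1}));
            Pair<Double, Long> nearest = holder.getCenterByMinDistance(query, Distance.EUCLIDEAN);
            // Center 0 wins, at distance sqrt(3) ~ 1.73
            System.out.println("center " + nearest.getSecond() + " at " + nearest.getFirst());
        }
    }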
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/Cluster.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/Cluster.java
deleted file mode 100644
index 7f4f221e5..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/Cluster.java
+++ /dev/null
@@ -1,150 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.cluster;
-
-import lombok.Data;
-import org.deeplearning4j.clustering.algorithm.Distance;
-import org.nd4j.linalg.factory.Nd4j;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.UUID;
-
-@Data
-public class Cluster implements Serializable {
-
- private String id = UUID.randomUUID().toString();
- private String label;
-
- private Point center;
- private List<Point> points = Collections.synchronizedList(new ArrayList<>());
- private boolean inverse = false;
- private Distance distanceFunction;
-
- public Cluster() {
- super();
- }
-
- /**
- *
- * @param center
- * @param distanceFunction
- */
- public Cluster(Point center, Distance distanceFunction) {
- this(center, false, distanceFunction);
- }
-
- /**
-  *
-  * @param center
-  * @param inverse
-  * @param distanceFunction
-  */
- public Cluster(Point center, boolean inverse, Distance distanceFunction) {
- this.distanceFunction = distanceFunction;
- this.inverse = inverse;
- setCenter(center);
- }
-
- /**
- * Get the distance to the given
- * point from the cluster
- * @param point the point to get the distance for
- * @return
- */
- public double getDistanceToCenter(Point point) {
- return Nd4j.getExecutioner().execAndReturn(
- ClusterUtils.createDistanceFunctionOp(distanceFunction, center.getArray(), point.getArray()))
- .getFinalResult().doubleValue();
- }
-
- /**
- * Add a point to the cluster
- * @param point
- */
- public void addPoint(Point point) {
- addPoint(point, true);
- }
-
- /**
- * Add a point to the cluster
- * @param point the point to add
- * @param moveClusterCenter whether to update
- * the cluster centroid or not
- */
- public void addPoint(Point point, boolean moveClusterCenter) {
- if (moveClusterCenter) {
- if (isInverse()) {
- center.getArray().muli(points.size()).subi(point.getArray()).divi(points.size() + 1);
- } else {
- center.getArray().muli(points.size()).addi(point.getArray()).divi(points.size() + 1);
- }
- }
-
- getPoints().add(point);
- }
-
- /**
- * Clear out the points
- */
- public void removePoints() {
- if (getPoints() != null)
- getPoints().clear();
- }
-
- /**
- * Whether the cluster is empty or not
- * @return
- */
- public boolean isEmpty() {
- return points == null || points.isEmpty();
- }
-
- /**
- * Return the point with the given id
- * @param id
- * @return
- */
- public Point getPoint(String id) {
- for (Point point : points)
- if (id.equals(point.getId()))
- return point;
- return null;
- }
-
- /**
- * Remove the point and return it
- * @param id
- * @return
- */
- public Point removePoint(String id) {
- Point removePoint = null;
- for (Point point : points)
- if (id.equals(point.getId()))
- removePoint = point;
- if (removePoint != null)
- points.remove(removePoint);
- return removePoint;
- }
-
-
-}
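The addPoint logic above keeps the center as a running mean of the points added so far (the initial center gets weight zero on the first add). A brief sketch:

    import org.deeplearning4j.clustering.algorithm.Distance;
    import org.deeplearning4j.clustering.cluster.Cluster;
    import org.deeplearning4j.clustering.cluster.Point;
    import org.nd4j.linalg.factory.Nd4j;

    public class ClusterExample {
        public static void main(String[] args) {
            Cluster cluster = new Cluster(new Point(Nd4j.create(new double[] {5, 5})), Distance.EUCLIDEAN);
            // center = (old * n + new) / (n + 1): with n = 0 the first point replaces the center
            cluster.addPoint(new Point(Nd4j.create(new double[] {2, 2})));
            cluster.addPoint(new Point(Nd4j.create(new double[] {0, 0})));
            System.out.println(cluster.getCenter().getArray()); // [1.0, 1.0]
            System.out.println(cluster.getDistanceToCenter(new Point(Nd4j.create(new double[] {4, 5})))); // 5.0
        }
    }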
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterSet.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterSet.java
deleted file mode 100644
index dabfdc7a4..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterSet.java
+++ /dev/null
@@ -1,259 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.cluster;
-
-import lombok.Data;
-import org.deeplearning4j.clustering.algorithm.Distance;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.common.primitives.Pair;
-
-import java.io.Serializable;
-import java.util.*;
-
-@Data
-public class ClusterSet implements Serializable {
-
- private Distance distanceFunction;
- private List<Cluster> clusters;
- private CentersHolder centersHolder;
- private Map<String, String> pointDistribution;
- private boolean inverse;
-
- public ClusterSet(boolean inverse) {
- this(null, inverse, null);
- }
-
- public ClusterSet(Distance distanceFunction, boolean inverse, long[] shape) {
- this.distanceFunction = distanceFunction;
- this.inverse = inverse;
- this.clusters = Collections.synchronizedList(new ArrayList<>());
- this.pointDistribution = Collections.synchronizedMap(new HashMap<>());
- if (shape != null)
- this.centersHolder = new CentersHolder(shape[0], shape[1]);
- }
-
-
- public boolean isInverse() {
- return inverse;
- }
-
- /**
- *
- * @param center
- * @return
- */
- public Cluster addNewClusterWithCenter(Point center) {
- Cluster newCluster = new Cluster(center, distanceFunction);
- getClusters().add(newCluster);
- setPointLocation(center, newCluster);
- centersHolder.addCenter(center.getArray());
- return newCluster;
- }
-
- /**
- *
- * @param point
- * @return
- */
- public PointClassification classifyPoint(Point point) {
- return classifyPoint(point, true);
- }
-
- /**
- *
- * @param points
- */
- public void classifyPoints(List<Point> points) {
- classifyPoints(points, true);
- }
-
- /**
- *
- * @param points
- * @param moveClusterCenter
- */
- public void classifyPoints(List<Point> points, boolean moveClusterCenter) {
- for (Point point : points)
- classifyPoint(point, moveClusterCenter);
- }
-
- /**
- *
- * @param point
- * @param moveClusterCenter
- * @return
- */
- public PointClassification classifyPoint(Point point, boolean moveClusterCenter) {
- Pair nearestCluster = nearestCluster(point);
- Cluster newCluster = nearestCluster.getKey();
- boolean locationChange = isPointLocationChange(point, newCluster);
- addPointToCluster(point, newCluster, moveClusterCenter);
- return new PointClassification(nearestCluster.getKey(), nearestCluster.getValue(), locationChange);
- }
-
- private boolean isPointLocationChange(Point point, Cluster newCluster) {
- if (!getPointDistribution().containsKey(point.getId()))
- return true;
- return !getPointDistribution().get(point.getId()).equals(newCluster.getId());
- }
-
- private void addPointToCluster(Point point, Cluster cluster, boolean moveClusterCenter) {
- cluster.addPoint(point, moveClusterCenter);
- setPointLocation(point, cluster);
- }
-
- private void setPointLocation(Point point, Cluster cluster) {
- pointDistribution.put(point.getId(), cluster.getId());
- }
-
-
- /**
- *
- * @param point
- * @return
- */
- public Pair<Cluster, Double> nearestCluster(Point point) {
-
- /*double minDistance = isInverse() ? Float.MIN_VALUE : Float.MAX_VALUE;
-
- double currentDistance;
- for (Cluster cluster : getClusters()) {
- currentDistance = cluster.getDistanceToCenter(point);
- if (isInverse()) {
- if (currentDistance > minDistance) {
- minDistance = currentDistance;
- nearestCluster = cluster;
- }
- } else {
- if (currentDistance < minDistance) {
- minDistance = currentDistance;
- nearestCluster = cluster;
- }
- }
-
- }*/
-
- Pair<Double, Long> nearestCenterData = centersHolder.
- getCenterByMinDistance(point, distanceFunction);
- Cluster nearestCluster = getClusters().get(nearestCenterData.getSecond().intValue());
- double minDistance = nearestCenterData.getFirst();
- return Pair.of(nearestCluster, minDistance);
- }
-
- /**
- *
- * @param m1
- * @param m2
- * @return
- */
- public double getDistance(Point m1, Point m2) {
- return Nd4j.getExecutioner()
- .execAndReturn(ClusterUtils.createDistanceFunctionOp(distanceFunction, m1.getArray(), m2.getArray()))
- .getFinalResult().doubleValue();
- }
-
- /**
- *
- * @param point
- * @return
- */
- /*public double getDistanceFromNearestCluster(Point point) {
- return nearestCluster(point).getValue();
- }*/
-
-
- /**
- *
- * @param clusterId
- * @return
- */
- public String getClusterCenterId(String clusterId) {
- Point clusterCenter = getClusterCenter(clusterId);
- return clusterCenter == null ? null : clusterCenter.getId();
- }
-
- /**
- *
- * @param clusterId
- * @return
- */
- public Point getClusterCenter(String clusterId) {
- Cluster cluster = getCluster(clusterId);
- return cluster == null ? null : cluster.getCenter();
- }
-
- /**
- *
- * @param id
- * @return
- */
- public Cluster getCluster(String id) {
- for (int i = 0, j = clusters.size(); i < j; i++)
- if (id.equals(clusters.get(i).getId()))
- return clusters.get(i);
- return null;
- }
-
- /**
- *
- * @return
- */
- public int getClusterCount() {
- return getClusters() == null ? 0 : getClusters().size();
- }
-
- /**
- *
- */
- public void removePoints() {
- for (Cluster cluster : getClusters())
- cluster.removePoints();
- }
-
- /**
- *
- * @param count
- * @return
- */
- public List<Cluster> getMostPopulatedClusters(int count) {
- List<Cluster> mostPopulated = new ArrayList<>(clusters);
- Collections.sort(mostPopulated, new Comparator<Cluster>() {
- public int compare(Cluster o1, Cluster o2) {
- return Integer.compare(o2.getPoints().size(), o1.getPoints().size());
- }
- });
- return mostPopulated.subList(0, count);
- }
-
- /**
- *
- * @return
- */
- public List<Cluster> removeEmptyClusters() {
- List<Cluster> emptyClusters = new ArrayList<>();
- for (Cluster cluster : clusters)
- if (cluster.isEmpty())
- emptyClusters.add(cluster);
- clusters.removeAll(emptyClusters);
- return emptyClusters;
- }
-
-}
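A sketch of the classification flow above: create a set sized for its centers, add clusters, and classify a point to the nearest center.

    import org.deeplearning4j.clustering.algorithm.Distance;
    import org.deeplearning4j.clustering.cluster.ClusterSet;
    import org.deeplearning4j.clustering.cluster.Point;
    import org.deeplearning4j.clustering.cluster.PointClassification;
    import org.nd4j.linalg.factory.Nd4j;

    public class ClusterSetExample {
        public static void main(String[] args) {
            // shape = {number of centers, dimensionality}
            ClusterSet clusterSet = new ClusterSet(Distance.EUCLIDEAN, false, new long[] {2, 2});
            clusterSet.addNewClusterWithCenter(new Point(Nd4j.create(new double[] {0, 0})));
            clusterSet.addNewClusterWithCenter(new Point(Nd4j.create(new double[] {10, 10})));

            PointClassification pc = clusterSet.classifyPoint(new Point(Nd4j.create(new double[] {9, 9})));
            // Nearest center is (10, 10), distance sqrt(2) ~ 1.41
            System.out.println(pc.getDistanceFromCenter() + ", new location: " + pc.isNewLocation());
        }
    }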
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterUtils.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterUtils.java
deleted file mode 100644
index ac1786538..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/ClusterUtils.java
+++ /dev/null
@@ -1,531 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.cluster;
-
-import lombok.AccessLevel;
-import lombok.NoArgsConstructor;
-import lombok.extern.slf4j.Slf4j;
-import lombok.val;
-import org.apache.commons.lang3.ArrayUtils;
-import org.deeplearning4j.clustering.algorithm.Distance;
-import org.deeplearning4j.clustering.info.ClusterInfo;
-import org.deeplearning4j.clustering.info.ClusterSetInfo;
-import org.deeplearning4j.clustering.optimisation.ClusteringOptimizationType;
-import org.deeplearning4j.clustering.strategy.OptimisationStrategy;
-import org.deeplearning4j.clustering.util.MathUtils;
-import org.deeplearning4j.clustering.util.MultiThreadUtils;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.ReduceOp;
-import org.nd4j.linalg.api.ops.impl.reduce3.*;
-import org.nd4j.linalg.factory.Nd4j;
-
-import java.util.*;
-import java.util.concurrent.ExecutorService;
-
-@NoArgsConstructor(access = AccessLevel.PRIVATE)
-@Slf4j
-public class ClusterUtils {
-
- /** Classify the set of points based on cluster centers. This also adds each point to the ClusterSet */
- public static ClusterSetInfo classifyPoints(final ClusterSet clusterSet, List<Point> points,
- ExecutorService executorService) {
- final ClusterSetInfo clusterSetInfo = ClusterSetInfo.initialize(clusterSet, true);
-
- List<Runnable> tasks = new ArrayList<>();
- for (final Point point : points) {
- //tasks.add(new Runnable() {
- // public void run() {
- try {
- PointClassification result = classifyPoint(clusterSet, point);
- if (result.isNewLocation())
- clusterSetInfo.getPointLocationChange().incrementAndGet();
- clusterSetInfo.getClusterInfo(result.getCluster().getId()).getPointDistancesFromCenter()
- .put(point.getId(), result.getDistanceFromCenter());
- } catch (Throwable t) {
- log.warn("Error classifying point", t);
- }
- // }
- }
-
- //MultiThreadUtils.parallelTasks(tasks, executorService);
- return clusterSetInfo;
- }
-
- public static PointClassification classifyPoint(ClusterSet clusterSet, Point point) {
- return clusterSet.classifyPoint(point, false);
- }
-
- public static void refreshClustersCenters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo,
- ExecutorService executorService) {
- List<Runnable> tasks = new ArrayList<>();
- int nClusters = clusterSet.getClusterCount();
- for (int i = 0; i < nClusters; i++) {
- final Cluster cluster = clusterSet.getClusters().get(i);
- //tasks.add(new Runnable() {
- // public void run() {
- try {
- final ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
- refreshClusterCenter(cluster, clusterInfo);
- deriveClusterInfoDistanceStatistics(clusterInfo);
- } catch (Throwable t) {
- log.warn("Error refreshing cluster centers", t);
- }
- // }
- //});
- }
- //MultiThreadUtils.parallelTasks(tasks, executorService);
- }
-
- public static void refreshClusterCenter(Cluster cluster, ClusterInfo clusterInfo) {
- int pointsCount = cluster.getPoints().size();
- if (pointsCount == 0)
- return;
- Point center = new Point(Nd4j.create(cluster.getPoints().get(0).getArray().length()));
- for (Point point : cluster.getPoints()) {
- INDArray arr = point.getArray();
- if (cluster.isInverse())
- center.getArray().subi(arr);
- else
- center.getArray().addi(arr);
- }
- center.getArray().divi(pointsCount);
- cluster.setCenter(center);
- }
-
- /**
- *
- * @param info
- */
- public static void deriveClusterInfoDistanceStatistics(ClusterInfo info) {
- int pointCount = info.getPointDistancesFromCenter().size();
- if (pointCount == 0)
- return;
-
- double[] distances =
- ArrayUtils.toPrimitive(info.getPointDistancesFromCenter().values().toArray(new Double[] {}));
- double max = info.isInverse() ? MathUtils.min(distances) : MathUtils.max(distances);
- double total = MathUtils.sum(distances);
- info.setMaxPointDistanceFromCenter(max);
- info.setTotalPointDistanceFromCenter(total);
- info.setAveragePointDistanceFromCenter(total / pointCount);
- info.setPointDistanceFromCenterVariance(MathUtils.variance(distances));
- }
-
- /**
- *
- * @param clusterSet
- * @param points
- * @param previousDxs
- * @param executorService
- * @return
- */
- public static INDArray computeSquareDistancesFromNearestCluster(final ClusterSet clusterSet,
- final List<Point> points, INDArray previousDxs, ExecutorService executorService) {
- final int pointsCount = points.size();
- final INDArray dxs = Nd4j.create(pointsCount);
- final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1);
-
- List<Runnable> tasks = new ArrayList<>();
- for (int i = 0; i < pointsCount; i++) {
- final int i2 = i;
- //tasks.add(new Runnable() {
- // public void run() {
- try {
- Point point = points.get(i2);
- double dist = clusterSet.isInverse() ? newCluster.getDistanceToCenter(point)
- : Math.pow(newCluster.getDistanceToCenter(point), 2);
- dxs.putScalar(i2, /*clusterSet.isInverse() ? dist :*/ dist);
- } catch (Throwable t) {
- log.warn("Error computing squared distance from nearest cluster", t);
- }
- // }
- //});
-
- }
-
- //MultiThreadUtils.parallelTasks(tasks, executorService);
- for (int i = 0; i < pointsCount; i++) {
- double previousMinDistance = previousDxs.getDouble(i);
- if (clusterSet.isInverse()) {
- if (dxs.getDouble(i) < previousMinDistance) {
-
- dxs.putScalar(i, previousMinDistance);
- }
- } else if (dxs.getDouble(i) > previousMinDistance)
- dxs.putScalar(i, previousMinDistance);
- }
-
- return dxs;
- }
-
- public static INDArray computeWeightedProbaDistancesFromNearestCluster(final ClusterSet clusterSet,
- final List<Point> points, INDArray previousDxs) {
- final int pointsCount = points.size();
- final INDArray dxs = Nd4j.create(pointsCount);
- final Cluster newCluster = clusterSet.getClusters().get(clusterSet.getClusters().size() - 1);
-
- double sum = 0.0;
- for (int i = 0; i < pointsCount; i++) {
-
- Point point = points.get(i);
- double dist = Math.pow(newCluster.getDistanceToCenter(point), 2);
- sum += dist;
- dxs.putScalar(i, sum);
- }
-
- return dxs;
- }
- /**
- *
- * @param clusterSet
- * @return
- */
- public static ClusterSetInfo computeClusterSetInfo(ClusterSet clusterSet) {
- ExecutorService executor = MultiThreadUtils.newExecutorService();
- ClusterSetInfo info = computeClusterSetInfo(clusterSet, executor);
- executor.shutdownNow();
- return info;
- }
-
- public static ClusterSetInfo computeClusterSetInfo(final ClusterSet clusterSet, ExecutorService executorService) {
- final ClusterSetInfo info = new ClusterSetInfo(clusterSet.isInverse(), true);
- int clusterCount = clusterSet.getClusterCount();
-
- List<Runnable> tasks = new ArrayList<>();
- for (int i = 0; i < clusterCount; i++) {
- final Cluster cluster = clusterSet.getClusters().get(i);
- //tasks.add(new Runnable() {
- // public void run() {
- try {
- info.getClustersInfos().put(cluster.getId(),
- computeClusterInfos(cluster, clusterSet.getDistanceFunction()));
- } catch (Throwable t) {
- log.warn("Error computing cluster set info", t);
- }
- //}
- //});
- }
-
-
- //MultiThreadUtils.parallelTasks(tasks, executorService);
-
- //tasks = new ArrayList<>();
- for (int i = 0; i < clusterCount; i++) {
- final int clusterIdx = i;
- final Cluster fromCluster = clusterSet.getClusters().get(i);
- //tasks.add(new Runnable() {
- //public void run() {
- try {
- for (int k = clusterIdx + 1, l = clusterSet.getClusterCount(); k < l; k++) {
- Cluster toCluster = clusterSet.getClusters().get(k);
- double distance = Nd4j.getExecutioner()
- .execAndReturn(ClusterUtils.createDistanceFunctionOp(
- clusterSet.getDistanceFunction(),
- fromCluster.getCenter().getArray(),
- toCluster.getCenter().getArray()))
- .getFinalResult().doubleValue();
- info.getDistancesBetweenClustersCenters().put(fromCluster.getId(), toCluster.getId(),
- distance);
- }
- } catch (Throwable t) {
- log.warn("Error computing distances", t);
- }
- // }
- //});
-
- }
-
- //MultiThreadUtils.parallelTasks(tasks, executorService);
-
- return info;
- }
-
- /**
- *
- * @param cluster
- * @param distanceFunction
- * @return
- */
- public static ClusterInfo computeClusterInfos(Cluster cluster, Distance distanceFunction) {
- ClusterInfo info = new ClusterInfo(cluster.isInverse(), true);
- for (int i = 0, j = cluster.getPoints().size(); i < j; i++) {
- Point point = cluster.getPoints().get(i);
- //shouldn't need to inverse here. other parts of
- //the code should interpret the "distance" or score here
- double distance = Nd4j.getExecutioner()
- .execAndReturn(ClusterUtils.createDistanceFunctionOp(distanceFunction,
- cluster.getCenter().getArray(), point.getArray()))
- .getFinalResult().doubleValue();
- info.getPointDistancesFromCenter().put(point.getId(), distance);
- double diff = info.getTotalPointDistanceFromCenter() + distance;
- info.setTotalPointDistanceFromCenter(diff);
- }
-
- if (!cluster.getPoints().isEmpty())
- info.setAveragePointDistanceFromCenter(info.getTotalPointDistanceFromCenter() / cluster.getPoints().size());
- return info;
- }
-
- /**
- *
- * @param optimization
- * @param clusterSet
- * @param clusterSetInfo
- * @param executor
- * @return
- */
- public static boolean applyOptimization(OptimisationStrategy optimization, ClusterSet clusterSet,
- ClusterSetInfo clusterSetInfo, ExecutorService executor) {
-
- if (optimization.isClusteringOptimizationType(
- ClusteringOptimizationType.MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE)) {
- int splitCount = ClusterUtils.splitClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet,
- clusterSetInfo, optimization.getClusteringOptimizationValue(), executor);
- return splitCount > 0;
- }
-
- if (optimization.isClusteringOptimizationType(
- ClusteringOptimizationType.MINIMIZE_MAXIMUM_POINT_TO_CENTER_DISTANCE)) {
- int splitCount = ClusterUtils.splitClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet,
- clusterSetInfo, optimization.getClusteringOptimizationValue(), executor);
- return splitCount > 0;
- }
-
- return false;
- }
-
- /**
- *
- * @param clusterSet
- * @param info
- * @param count
- * @return
- */
- public static List<Cluster> getMostSpreadOutClusters(final ClusterSet clusterSet, final ClusterSetInfo info,
- int count) {
- List<Cluster> clusters = new ArrayList<>(clusterSet.getClusters());
- Collections.sort(clusters, new Comparator<Cluster>() {
- public int compare(Cluster o1, Cluster o2) {
- Double o1TotalDistance = info.getClusterInfo(o1.getId()).getTotalPointDistanceFromCenter();
- Double o2TotalDistance = info.getClusterInfo(o2.getId()).getTotalPointDistanceFromCenter();
- int comp = o1TotalDistance.compareTo(o2TotalDistance);
- return !clusterSet.getClusters().get(0).isInverse() ? -comp : comp;
- }
- });
-
- return clusters.subList(0, count);
- }
-
- /**
- *
- * @param clusterSet
- * @param info
- * @param maximumAverageDistance
- * @return
- */
- public static List<Cluster> getClustersWhereAverageDistanceFromCenterGreaterThan(final ClusterSet clusterSet,
- final ClusterSetInfo info, double maximumAverageDistance) {
- List<Cluster> clusters = new ArrayList<>();
- for (Cluster cluster : clusterSet.getClusters()) {
- ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId());
- if (clusterInfo != null) {
- //distances
- if (clusterInfo.isInverse()) {
- if (clusterInfo.getAveragePointDistanceFromCenter() < maximumAverageDistance)
- clusters.add(cluster);
- } else {
- if (clusterInfo.getAveragePointDistanceFromCenter() > maximumAverageDistance)
- clusters.add(cluster);
- }
-
- }
-
- }
- return clusters;
- }
-
- /**
- *
- * @param clusterSet
- * @param info
- * @param maximumDistance
- * @return
- */
- public static List<Cluster> getClustersWhereMaximumDistanceFromCenterGreaterThan(final ClusterSet clusterSet,
- final ClusterSetInfo info, double maximumDistance) {
- List<Cluster> clusters = new ArrayList<>();
- for (Cluster cluster : clusterSet.getClusters()) {
- ClusterInfo clusterInfo = info.getClusterInfo(cluster.getId());
- if (clusterInfo != null) {
- if (clusterInfo.isInverse() && clusterInfo.getMaxPointDistanceFromCenter() < maximumDistance) {
- clusters.add(cluster);
- } else if (clusterInfo.getMaxPointDistanceFromCenter() > maximumDistance) {
- clusters.add(cluster);
-
- }
- }
- }
- return clusters;
- }
-
- /**
- *
- * @param clusterSet
- * @param clusterSetInfo
- * @param count
- * @param executorService
- * @return
- */
- public static int splitMostSpreadOutClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count,
- ExecutorService executorService) {
- List<Cluster> clustersToSplit = getMostSpreadOutClusters(clusterSet, clusterSetInfo, count);
- splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService);
- return clustersToSplit.size();
- }
-
- /**
- *
- * @param clusterSet
- * @param clusterSetInfo
- * @param maxWithinClusterDistance
- * @param executorService
- * @return
- */
- public static int splitClustersWhereAverageDistanceFromCenterGreaterThan(ClusterSet clusterSet,
- ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) {
- List<Cluster> clustersToSplit = getClustersWhereAverageDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo,
- maxWithinClusterDistance);
- splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService);
- return clustersToSplit.size();
- }
-
- /**
- *
- * @param clusterSet
- * @param clusterSetInfo
- * @param maxWithinClusterDistance
- * @param executorService
- * @return
- */
- public static int splitClustersWhereMaximumDistanceFromCenterGreaterThan(ClusterSet clusterSet,
- ClusterSetInfo clusterSetInfo, double maxWithinClusterDistance, ExecutorService executorService) {
- List<Cluster> clustersToSplit = getClustersWhereMaximumDistanceFromCenterGreaterThan(clusterSet, clusterSetInfo,
- maxWithinClusterDistance);
- splitClusters(clusterSet, clusterSetInfo, clustersToSplit, maxWithinClusterDistance, executorService);
- return clustersToSplit.size();
- }
-
- /**
- *
- * @param clusterSet
- * @param clusterSetInfo
- * @param count
- * @param executorService
- */
- public static void splitMostPopulatedClusters(ClusterSet clusterSet, ClusterSetInfo clusterSetInfo, int count,
- ExecutorService executorService) {
- List<Cluster> clustersToSplit = clusterSet.getMostPopulatedClusters(count);
- splitClusters(clusterSet, clusterSetInfo, clustersToSplit, executorService);
- }
-
- /**
- *
- * @param clusterSet
- * @param clusterSetInfo
- * @param clusters
- * @param maxDistance
- * @param executorService
- */
- public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo,
- List<Cluster> clusters, final double maxDistance, ExecutorService executorService) {
- final Random random = new Random();
- List<Runnable> tasks = new ArrayList<>();
- for (final Cluster cluster : clusters) {
- tasks.add(new Runnable() {
- public void run() {
- try {
- ClusterInfo clusterInfo = clusterSetInfo.getClusterInfo(cluster.getId());
- List<String> fartherPoints = clusterInfo.getPointsFartherFromCenterThan(maxDistance);
- int rank = Math.min(fartherPoints.size(), 3);
- String pointId = fartherPoints.get(random.nextInt(rank));
- Point point = cluster.removePoint(pointId);
- clusterSet.addNewClusterWithCenter(point);
- } catch (Throwable t) {
- log.warn("Error splitting clusters", t);
- }
- }
- });
- }
- MultiThreadUtils.parallelTasks(tasks, executorService);
- }
-
- /**
- *
- * @param clusterSet
- * @param clusterSetInfo
- * @param clusters
- * @param executorService
- */
- public static void splitClusters(final ClusterSet clusterSet, final ClusterSetInfo clusterSetInfo,
- List<Cluster> clusters, ExecutorService executorService) {
- final Random random = new Random();
- List<Runnable> tasks = new ArrayList<>();
- for (final Cluster cluster : clusters) {
- tasks.add(new Runnable() {
- public void run() {
- try {
- Point point = cluster.getPoints().remove(random.nextInt(cluster.getPoints().size()));
- clusterSet.addNewClusterWithCenter(point);
- } catch (Throwable t) {
- log.warn("Error Splitting clusters (2)", t);
- }
- }
- });
- }
-
- MultiThreadUtils.parallelTasks(tasks, executorService);
- }
-
- public static ReduceOp createDistanceFunctionOp(Distance distanceFunction, INDArray x, INDArray y, int...dimensions){
- val op = createDistanceFunctionOp(distanceFunction, x, y);
- op.setDimensions(dimensions);
- return op;
- }
-
- public static ReduceOp createDistanceFunctionOp(Distance distanceFunction, INDArray x, INDArray y){
- switch (distanceFunction){
- case COSINE_DISTANCE:
- return new CosineDistance(x,y);
- case COSINE_SIMILARITY:
- return new CosineSimilarity(x,y);
- case DOT:
- return new Dot(x,y);
- case EUCLIDEAN:
- return new EuclideanDistance(x,y);
- case JACCARD:
- return new JaccardDistance(x,y);
- case MANHATTAN:
- return new ManhattanDistance(x,y);
- default:
- throw new IllegalStateException("Unknown distance function: " + distanceFunction);
- }
- }
-}
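For reference, the factory above maps the Distance enum onto nd4j reduce3 ops; used standalone it computes a scalar distance exactly as the rest of this class does:

    import org.deeplearning4j.clustering.algorithm.Distance;
    import org.deeplearning4j.clustering.cluster.ClusterUtils;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class DistanceOpExample {
        public static void main(String[] args) {
            INDArray x = Nd4j.create(new double[] {0, 0, 0});
            INDArray y = Nd4j.create(new double[] {3, 4, 0});
            double d = Nd4j.getExecutioner()
                    .execAndReturn(ClusterUtils.createDistanceFunctionOp(Distance.EUCLIDEAN, x, y))
                    .getFinalResult().doubleValue();
            System.out.println(d); // 5.0
        }
    }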
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/Point.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/Point.java
deleted file mode 100644
index 14147b004..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/Point.java
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.cluster;
-
-import lombok.AccessLevel;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.UUID;
-
-/**
- * A point in n-dimensional space, identified by an id and an optional label,
- * backed by an {@link INDArray}.
- */
-@Data
-@NoArgsConstructor(access = AccessLevel.PROTECTED)
-public class Point implements Serializable {
-
- private static final long serialVersionUID = -6658028541426027226L;
-
- private String id = UUID.randomUUID().toString();
- private String label;
- private INDArray array;
-
-
- /**
- * @param array the coordinates of the point
- */
- public Point(INDArray array) {
- super();
- this.array = array;
- }
-
- /**
- * @param id the id of the point
- * @param array the coordinates of the point
- */
- public Point(String id, INDArray array) {
- super();
- this.id = id;
- this.array = array;
- }
-
- public Point(String id, String label, double[] data) {
- this(id, label, Nd4j.create(data));
- }
-
- public Point(String id, String label, INDArray array) {
- super();
- this.id = id;
- this.label = label;
- this.array = array;
- }
-
-
- /**
- * Wrap each row of the given matrix in a {@link Point}.
- *
- * @param matrix the matrix of points, one point per row
- * @return one point per row of the matrix
- */
- public static List<Point> toPoints(INDArray matrix) {
- List<Point> arr = new ArrayList<>(matrix.rows());
- for (int i = 0; i < matrix.rows(); i++) {
- arr.add(new Point(matrix.getRow(i)));
- }
-
- return arr;
- }
-
- /**
- * Wrap each vector in the given list in a {@link Point}.
- *
- * @param vectors the vectors to wrap
- * @return one point per vector
- */
- public static List<Point> toPoints(List<INDArray> vectors) {
- List<Point> points = new ArrayList<>();
- for (INDArray vector : vectors)
- points.add(new Point(vector));
- return points;
- }
-
-
-}
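A minimal usage sketch for the factory methods above (shapes are illustrative):

    import java.util.List;
    import org.deeplearning4j.clustering.cluster.Point;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class PointSketch {
        public static void main(String[] args) {
            INDArray data = Nd4j.rand(5, 3);           // 5 points in R^3, one per row
            List<Point> points = Point.toPoints(data); // wraps each row
            System.out.println(points.size());         // 5
            System.out.println(points.get(0).getId()); // a random UUID (Lombok @Data getter)
        }
    }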
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/PointClassification.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/PointClassification.java
deleted file mode 100644
index 6951b4a03..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/PointClassification.java
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.cluster;
-
-import lombok.AccessLevel;
-import lombok.AllArgsConstructor;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-
-import java.io.Serializable;
-
-@Data
-@NoArgsConstructor(access = AccessLevel.PROTECTED)
-@AllArgsConstructor
-public class PointClassification implements Serializable {
-
- private Cluster cluster;
- private double distanceFromCenter;
- private boolean newLocation;
-
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/ClusteringAlgorithmCondition.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/ClusteringAlgorithmCondition.java
deleted file mode 100644
index 852a58920..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/ClusteringAlgorithmCondition.java
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.condition;
-
-import org.deeplearning4j.clustering.iteration.IterationHistory;
-
-/**
- * A termination condition for a clustering algorithm.
- */
-public interface ClusteringAlgorithmCondition {
-
- /**
- * @param iterationHistory the history of the iterations run so far
- * @return true if the condition is satisfied and the algorithm should stop
- */
- boolean isSatisfied(IterationHistory iterationHistory);
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/ConvergenceCondition.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/ConvergenceCondition.java
deleted file mode 100644
index 6c2659f60..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/ConvergenceCondition.java
+++ /dev/null
@@ -1,69 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.condition;
-
-import lombok.AccessLevel;
-import lombok.AllArgsConstructor;
-import lombok.NoArgsConstructor;
-import org.deeplearning4j.clustering.iteration.IterationHistory;
-import org.nd4j.linalg.indexing.conditions.Condition;
-import org.nd4j.linalg.indexing.conditions.LessThan;
-
-import java.io.Serializable;
-
-@NoArgsConstructor(access = AccessLevel.PROTECTED)
-@AllArgsConstructor(access = AccessLevel.PROTECTED)
-public class ConvergenceCondition implements ClusteringAlgorithmCondition, Serializable {
-
- private Condition convergenceCondition;
- private double pointsDistributionChangeRate;
-
-
- /**
- * @param pointsDistributionChangeRate the fraction of points changing cluster below which the clustering is considered converged
- * @return a condition satisfied once the distribution variation rate drops below the given threshold
- */
- public static ConvergenceCondition distributionVariationRateLessThan(double pointsDistributionChangeRate) {
- Condition condition = new LessThan(pointsDistributionChangeRate);
- return new ConvergenceCondition(condition, pointsDistributionChangeRate);
- }
-
-
- /**
- * @param iterationHistory the history of the iterations run so far
- * @return true once the fraction of points that changed cluster in the most recent iteration satisfies the convergence condition
- */
- public boolean isSatisfied(IterationHistory iterationHistory) {
- int iterationCount = iterationHistory.getIterationCount();
- if (iterationCount <= 1)
- return false;
-
- double variation = iterationHistory.getMostRecentClusterSetInfo().getPointLocationChange().get();
- variation /= iterationHistory.getMostRecentClusterSetInfo().getPointsCount();
-
- return convergenceCondition.apply(variation);
- }
-
-
-
-}
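The convergence test above reduces to a single ratio: the number of points that changed cluster in the last iteration divided by the total point count, compared against the threshold via a LessThan condition. A worked sketch with made-up numbers:

    public class ConvergenceSketch {
        public static void main(String[] args) {
            int pointLocationChange = 12;  // points that moved to another cluster
            int pointsCount = 1000;        // total points across all clusters
            double rate = (double) pointLocationChange / pointsCount; // 0.012
            // equivalent of distributionVariationRateLessThan(0.05)
            System.out.println(rate < 0.05); // true: considered converged
        }
    }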
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/FixedIterationCountCondition.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/FixedIterationCountCondition.java
deleted file mode 100644
index 7eda7a7ec..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/FixedIterationCountCondition.java
+++ /dev/null
@@ -1,61 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.condition;
-
-import lombok.AccessLevel;
-import lombok.NoArgsConstructor;
-import org.deeplearning4j.clustering.iteration.IterationHistory;
-import org.nd4j.linalg.indexing.conditions.Condition;
-import org.nd4j.linalg.indexing.conditions.GreaterThanOrEqual;
-
-import java.io.Serializable;
-
-/**
- * Terminates the clustering once a fixed number of iterations has been reached.
- */
-@NoArgsConstructor(access = AccessLevel.PROTECTED)
-public class FixedIterationCountCondition implements ClusteringAlgorithmCondition, Serializable {
-
- private Condition iterationCountCondition;
-
- protected FixedIterationCountCondition(int initialClusterCount) {
- iterationCountCondition = new GreaterThanOrEqual(initialClusterCount);
- }
-
- /**
- * @param iterationCount the iteration count at which the condition becomes satisfied
- * @return the condition
- */
- public static FixedIterationCountCondition iterationCountGreaterThan(int iterationCount) {
- return new FixedIterationCountCondition(iterationCount);
- }
-
- /**
- * @param iterationHistory the history of the iterations run so far
- * @return true once the iteration count reaches the configured value
- */
- public boolean isSatisfied(IterationHistory iterationHistory) {
- return iterationCountCondition.apply(iterationHistory == null ? 0 : iterationHistory.getIterationCount());
- }
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/VarianceVariationCondition.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/VarianceVariationCondition.java
deleted file mode 100644
index ff91dd7eb..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/condition/VarianceVariationCondition.java
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.condition;
-
-import lombok.AccessLevel;
-import lombok.AllArgsConstructor;
-import lombok.NoArgsConstructor;
-import org.deeplearning4j.clustering.iteration.IterationHistory;
-import org.nd4j.linalg.indexing.conditions.Condition;
-import org.nd4j.linalg.indexing.conditions.LessThan;
-
-import java.io.Serializable;
-
-/**
- * Terminates the clustering once the variance of the point-to-center distances
- * has stabilized over a given number of iterations.
- */
-@NoArgsConstructor(access = AccessLevel.PROTECTED)
-@AllArgsConstructor
-public class VarianceVariationCondition implements ClusteringAlgorithmCondition, Serializable {
-
- private Condition varianceVariationCondition;
- private int period;
-
-
-
- /**
- * @param varianceVariation the relative variance change below which an iteration counts as stable
- * @param period the number of consecutive stable iterations required
- * @return the condition
- */
- public static VarianceVariationCondition varianceVariationLessThan(double varianceVariation, int period) {
- Condition condition = new LessThan(varianceVariation);
- return new VarianceVariationCondition(condition, period);
- }
-
-
- /**
- * @param iterationHistory the history of the iterations run so far
- * @return true if the variance variation stayed below the threshold for the whole period
- */
- public boolean isSatisfied(IterationHistory iterationHistory) {
- if (iterationHistory.getIterationCount() <= period)
- return false;
-
- for (int i = 0, j = iterationHistory.getIterationCount(); i < period; i++) {
- double variation = iterationHistory.getIterationInfo(j - i).getClusterSetInfo()
- .getPointDistanceFromClusterVariance();
- variation -= iterationHistory.getIterationInfo(j - i - 1).getClusterSetInfo()
- .getPointDistanceFromClusterVariance();
- variation /= iterationHistory.getIterationInfo(j - i - 1).getClusterSetInfo()
- .getPointDistanceFromClusterVariance();
-
- if (!varianceVariationCondition.apply(variation))
- return false;
- }
-
- return true;
- }
-
-
-
-}
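The loop above checks the signed relative change of the distance variance for each of the last `period` iterations, so the condition only triggers after a sustained plateau. A worked sketch of the per-iteration term (numbers made up):

    public class VarianceVariationSketch {
        public static void main(String[] args) {
            double previousVariance = 2.00; // iteration j - 1
            double currentVariance = 1.98;  // iteration j
            double variation = (currentVariance - previousVariance) / previousVariance; // -0.01
            // equivalent of varianceVariationLessThan(0.05, period): the signed value
            // must stay below the threshold for every iteration in the period
            System.out.println(variation < 0.05); // true for this iteration
        }
    }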
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterInfo.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterInfo.java
deleted file mode 100644
index 2b78ee3e8..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterInfo.java
+++ /dev/null
@@ -1,114 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.info;
-
-import lombok.Data;
-
-import java.io.Serializable;
-import java.util.*;
-import java.util.concurrent.ConcurrentHashMap;
-
-/**
- * Distance statistics for a single cluster: the per-point distances from the
- * center plus their average, maximum, variance and total.
- */
-@Data
-public class ClusterInfo implements Serializable {
-
- private double averagePointDistanceFromCenter;
- private double maxPointDistanceFromCenter;
- private double pointDistanceFromCenterVariance;
- private double totalPointDistanceFromCenter;
- private boolean inverse;
- private Map<String, Double> pointDistancesFromCenter = new ConcurrentHashMap<>();
-
- public ClusterInfo(boolean inverse) {
- this(false, inverse);
- }
-
- /**
- * @param threadSafe whether the point-distance map must be safe for concurrent updates
- * @param inverse whether the distance measure is an inverse (similarity) measure
- */
- public ClusterInfo(boolean threadSafe, boolean inverse) {
- super();
- this.inverse = inverse;
- if (threadSafe) {
- pointDistancesFromCenter = Collections.synchronizedMap(pointDistancesFromCenter);
- }
- }
-
- /**
- * @return the point distances from the center, sorted in ascending order
- */
- public Set<Map.Entry<String, Double>> getSortedPointDistancesFromCenter() {
- SortedSet<Map.Entry<String, Double>> sortedEntries = new TreeSet<>(new Comparator<Map.Entry<String, Double>>() {
- @Override
- public int compare(Map.Entry<String, Double> e1, Map.Entry<String, Double> e2) {
- int res = e1.getValue().compareTo(e2.getValue());
- // never return 0: a TreeSet would otherwise drop points at equal distances
- return res != 0 ? res : 1;
- }
- });
- sortedEntries.addAll(pointDistancesFromCenter.entrySet());
- return sortedEntries;
- }
-
- /**
- * @return the point distances from the center, sorted in descending order
- */
- public Set<Map.Entry<String, Double>> getReverseSortedPointDistancesFromCenter() {
- SortedSet<Map.Entry<String, Double>> sortedEntries = new TreeSet<>(new Comparator<Map.Entry<String, Double>>() {
- @Override
- public int compare(Map.Entry<String, Double> e1, Map.Entry<String, Double> e2) {
- int res = e1.getValue().compareTo(e2.getValue());
- return -(res != 0 ? res : 1);
- }
- });
- sortedEntries.addAll(pointDistancesFromCenter.entrySet());
- return sortedEntries;
- }
-
- /**
- * @param maxDistance the distance threshold
- * @return the ids of the points selected by the distance threshold, farthest first
- */
- public List<String> getPointsFartherFromCenterThan(double maxDistance) {
- Set<Map.Entry<String, Double>> sorted = getReverseSortedPointDistancesFromCenter();
- List<String> ids = new ArrayList<>();
- for (Map.Entry<String, Double> entry : sorted) {
- if (inverse && entry.getValue() < -maxDistance)
- break;
- else if (entry.getValue() > maxDistance)
- break;
- ids.add(entry.getKey());
- }
- return ids;
- }
-
-
-
-}
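The comparators above deliberately never return 0: a TreeSet treats compare() == 0 as a duplicate, so two points at exactly the same distance would otherwise silently collapse into one entry. A self-contained sketch of the pattern:

    import java.util.AbstractMap;
    import java.util.Map;
    import java.util.SortedSet;
    import java.util.TreeSet;

    public class TieBreakSketch {
        public static void main(String[] args) {
            SortedSet<Map.Entry<String, Double>> set = new TreeSet<>((e1, e2) -> {
                int res = e1.getValue().compareTo(e2.getValue());
                return res != 0 ? res : 1; // never 0: ties are kept, not deduplicated
            });
            set.add(new AbstractMap.SimpleEntry<>("a", 1.0));
            set.add(new AbstractMap.SimpleEntry<>("b", 1.0)); // same distance, still added
            System.out.println(set.size()); // 2
        }
    }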
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java
deleted file mode 100644
index 3ddfd1b25..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/info/ClusterSetInfo.java
+++ /dev/null
@@ -1,142 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.info;
-
-import org.nd4j.shade.guava.collect.HashBasedTable;
-import org.nd4j.shade.guava.collect.Table;
-import org.deeplearning4j.clustering.cluster.Cluster;
-import org.deeplearning4j.clustering.cluster.ClusterSet;
-
-import java.io.Serializable;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.concurrent.atomic.AtomicInteger;
-
-public class ClusterSetInfo implements Serializable {
-
- private Map<String, ClusterInfo> clustersInfos = new HashMap<>();
- private Table<String, String, Double> distancesBetweenClustersCenters = HashBasedTable.create();
- private AtomicInteger pointLocationChange;
- private boolean threadSafe;
- private boolean inverse;
-
- public ClusterSetInfo(boolean inverse) {
- this(inverse, false);
- }
-
- /**
- * @param inverse whether the distance measure is an inverse (similarity) measure
- * @param threadSafe whether the cluster info map must be safe for concurrent updates
- */
- public ClusterSetInfo(boolean inverse, boolean threadSafe) {
- this.pointLocationChange = new AtomicInteger(0);
- this.threadSafe = threadSafe;
- this.inverse = inverse;
- if (threadSafe) {
- clustersInfos = Collections.synchronizedMap(clustersInfos);
- }
- }
-
-
- /**
- * Creates a ClusterSetInfo with one empty ClusterInfo per cluster in the set.
- *
- * @param clusterSet the cluster set to initialize from
- * @param threadSafe whether the resulting info objects must be thread safe
- * @return the initialized ClusterSetInfo
- */
- public static ClusterSetInfo initialize(ClusterSet clusterSet, boolean threadSafe) {
- ClusterSetInfo info = new ClusterSetInfo(clusterSet.isInverse(), threadSafe);
- for (int i = 0, j = clusterSet.getClusterCount(); i < j; i++)
- info.addClusterInfo(clusterSet.getClusters().get(i).getId());
- return info;
- }
-
- public void removeClusterInfos(List<Cluster> clusters) {
- for (Cluster cluster : clusters) {
- clustersInfos.remove(cluster.getId());
- }
- }
-
- public ClusterInfo addClusterInfo(String clusterId) {
- // pass both flags: the two-argument constructor is (threadSafe, inverse)
- ClusterInfo clusterInfo = new ClusterInfo(this.threadSafe, this.inverse);
- clustersInfos.put(clusterId, clusterInfo);
- return clusterInfo;
- }
-
- public ClusterInfo getClusterInfo(String clusterId) {
- return clustersInfos.get(clusterId);
- }
-
- public double getAveragePointDistanceFromClusterCenter() {
- if (clustersInfos == null || clustersInfos.isEmpty())
- return 0;
-
- double average = 0;
- for (ClusterInfo info : clustersInfos.values())
- average += info.getAveragePointDistanceFromCenter();
- return average / clustersInfos.size();
- }
-
- public double getPointDistanceFromClusterVariance() {
- if (clustersInfos == null || clustersInfos.isEmpty())
- return 0;
-
- double average = 0;
- for (ClusterInfo info : clustersInfos.values())
- average += info.getPointDistanceFromCenterVariance();
- return average / clustersInfos.size();
- }
-
- public int getPointsCount() {
- int count = 0;
- for (ClusterInfo clusterInfo : clustersInfos.values())
- count += clusterInfo.getPointDistancesFromCenter().size();
- return count;
- }
-
- public Map<String, ClusterInfo> getClustersInfos() {
- return clustersInfos;
- }
-
- public void setClustersInfos(Map<String, ClusterInfo> clustersInfos) {
- this.clustersInfos = clustersInfos;
- }
-
- public Table<String, String, Double> getDistancesBetweenClustersCenters() {
- return distancesBetweenClustersCenters;
- }
-
- public void setDistancesBetweenClustersCenters(Table<String, String, Double> interClusterDistances) {
- this.distancesBetweenClustersCenters = interClusterDistances;
- }
-
- public AtomicInteger getPointLocationChange() {
- return pointLocationChange;
- }
-
- public void setPointLocationChange(AtomicInteger pointLocationChange) {
- this.pointLocationChange = pointLocationChange;
- }
-
-}
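Note that getAveragePointDistanceFromClusterCenter averages the per-cluster averages, so a tiny cluster influences the result as much as a huge one. A toy illustration of the difference against a point-weighted mean (values made up):

    public class MeanOfMeansSketch {
        public static void main(String[] args) {
            double[] clusterAvg = {1.0, 3.0}; // per-cluster average distances
            int[] clusterSize = {99, 1};      // points per cluster
            double meanOfMeans = (clusterAvg[0] + clusterAvg[1]) / 2.0;               // 2.00
            double weighted = (clusterAvg[0] * clusterSize[0] + clusterAvg[1] * clusterSize[1])
                    / (double) (clusterSize[0] + clusterSize[1]);                     // 1.02
            System.out.println(meanOfMeans + " vs " + weighted);
        }
    }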
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/iteration/IterationHistory.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/iteration/IterationHistory.java
deleted file mode 100644
index 0854e5eb1..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/iteration/IterationHistory.java
+++ /dev/null
@@ -1,72 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.iteration;
-
-import lombok.Getter;
-import lombok.Setter;
-import org.deeplearning4j.clustering.info.ClusterSetInfo;
-
-import java.io.Serializable;
-import java.util.HashMap;
-import java.util.Map;
-
-public class IterationHistory implements Serializable {
- @Getter
- @Setter
- private Map<Integer, IterationInfo> iterationsInfos = new HashMap<>();
-
- /**
- * @return the cluster set info of the most recent iteration, or null if no iteration has run yet
- */
- public ClusterSetInfo getMostRecentClusterSetInfo() {
- IterationInfo iterationInfo = getMostRecentIterationInfo();
- return iterationInfo == null ? null : iterationInfo.getClusterSetInfo();
- }
-
- /**
- * @return the info of the most recent iteration, or null if no iteration has run yet
- */
- public IterationInfo getMostRecentIterationInfo() {
- return getIterationInfo(getIterationCount() - 1);
- }
-
- /**
- * @return the number of iterations run so far
- */
- public int getIterationCount() {
- return getIterationsInfos().size();
- }
-
- /**
- * @param iterationIdx the index of the iteration
- * @return the info for that iteration
- */
- public IterationInfo getIterationInfo(int iterationIdx) {
- return getIterationsInfos().get(iterationIdx);
- }
-
-
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/iteration/IterationInfo.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/iteration/IterationInfo.java
deleted file mode 100644
index 0036f3c47..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/iteration/IterationInfo.java
+++ /dev/null
@@ -1,49 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.iteration;
-
-import lombok.AccessLevel;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-import org.deeplearning4j.clustering.info.ClusterSetInfo;
-
-import java.io.Serializable;
-
-@Data
-@NoArgsConstructor(access = AccessLevel.PROTECTED)
-public class IterationInfo implements Serializable {
-
- private int index;
- private ClusterSetInfo clusterSetInfo;
- private boolean strategyApplied;
-
- public IterationInfo(int index) {
- super();
- this.index = index;
- }
-
- public IterationInfo(int index, ClusterSetInfo clusterSetInfo) {
- super();
- this.index = index;
- this.clusterSetInfo = clusterSetInfo;
- }
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java
deleted file mode 100644
index c3e0bc418..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/HyperRect.java
+++ /dev/null
@@ -1,142 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.kdtree;
-
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.custom.KnnMinDistance;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.common.primitives.Pair;
-
-import java.io.Serializable;
-
-public class HyperRect implements Serializable {
-
- //private List points;
- private float[] lowerEnds;
- private float[] higherEnds;
- private INDArray lowerEndsIND;
- private INDArray higherEndsIND;
-
- public HyperRect(float[] lowerEndsIn, float[] higherEndsIn) {
- this.lowerEnds = new float[lowerEndsIn.length];
- this.higherEnds = new float[lowerEndsIn.length];
- System.arraycopy(lowerEndsIn, 0 , this.lowerEnds, 0, lowerEndsIn.length);
- System.arraycopy(higherEndsIn, 0 , this.higherEnds, 0, higherEndsIn.length);
- lowerEndsIND = Nd4j.createFromArray(lowerEnds);
- higherEndsIND = Nd4j.createFromArray(higherEnds);
- }
-
- public HyperRect(float[] point) {
- this(point, point);
- }
-
- public HyperRect(Pair<float[], float[]> ends) {
- this(ends.getFirst(), ends.getSecond());
- }
-
-
- public void enlargeTo(INDArray point) {
- float[] pointAsArray = point.toFloatVector();
- for (int i = 0; i < lowerEnds.length; i++) {
- float p = pointAsArray[i];
- if (lowerEnds[i] > p)
- lowerEnds[i] = p;
- else if (higherEnds[i] < p)
- higherEnds[i] = p;
- }
- }
-
- public static Pair<float[], float[]> point(INDArray vector) {
- Pair<float[], float[]> ret = new Pair<>();
- float[] curr = new float[(int) vector.length()];
- for (int i = 0; i < vector.length(); i++) {
- curr[i] = vector.getFloat(i);
- }
- ret.setFirst(curr);
- ret.setSecond(curr);
- return ret;
- }
-
-
- /*public List contains(INDArray hPoint) {
- List ret = new ArrayList<>();
- for (int i = 0; i < hPoint.length(); i++) {
- ret.add(lowerEnds[i] <= hPoint.getDouble(i) &&
- higherEnds[i] >= hPoint.getDouble(i));
- }
- return ret;
- }*/
-
- public double minDistance(INDArray hPoint, INDArray output) {
- Nd4j.exec(new KnnMinDistance(hPoint, lowerEndsIND, higherEndsIND, output));
- return output.getFloat(0);
-
- /*double ret = 0.0;
- double[] pointAsArray = hPoint.toDoubleVector();
- for (int i = 0; i < pointAsArray.length; i++) {
- double p = pointAsArray[i];
- if (!(lowerEnds[i] <= p || higherEnds[i] <= p)) {
- if (p < lowerEnds[i])
- ret += Math.pow((p - lowerEnds[i]), 2);
- else
- ret += Math.pow((p - higherEnds[i]), 2);
- }
- }
- ret = Math.pow(ret, 0.5);
- return ret;*/
- }
-
- public HyperRect getUpper(INDArray hPoint, int desc) {
- //Interval interval = points.get(desc);
- float higher = higherEnds[desc];
- float d = hPoint.getFloat(desc);
- if (higher < d)
- return null;
- HyperRect ret = new HyperRect(lowerEnds,higherEnds);
- if (ret.lowerEnds[desc] < d)
- ret.lowerEnds[desc] = d;
- return ret;
- }
-
- public HyperRect getLower(INDArray hPoint, int desc) {
- //Interval interval = points.get(desc);
- float lower = lowerEnds[desc];
- float d = hPoint.getFloat(desc);
- if (lower > d)
- return null;
- HyperRect ret = new HyperRect(lowerEnds,higherEnds);
- //Interval i2 = ret.points.get(desc);
- if (ret.higherEnds[desc] > d)
- ret.higherEnds[desc] = d;
- return ret;
- }
-
- @Override
- public String toString() {
- String retVal = "";
- retVal += "[";
- for (int i = 0; i < lowerEnds.length; ++i) {
- retVal += "(" + lowerEnds[i] + " - " + higherEnds[i] + ") ";
- }
- retVal += "]";
- return retVal;
- }
-}
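minDistance above delegates to the native KnnMinDistance op, but the commented-out block spells out the formula: clamp the query to the box along each axis and take the Euclidean norm of the residual. A plain-Java sketch of that reference computation:

    public class BoxDistanceSketch {
        // minimum Euclidean distance from point p to the axis-aligned box [lower, upper]
        static double minDistance(float[] lower, float[] upper, float[] p) {
            double sum = 0.0;
            for (int i = 0; i < p.length; i++) {
                if (p[i] < lower[i])
                    sum += (p[i] - lower[i]) * (p[i] - lower[i]); // below the box
                else if (p[i] > upper[i])
                    sum += (p[i] - upper[i]) * (p[i] - upper[i]); // above the box
                // otherwise inside the slab on this axis: contributes nothing
            }
            return Math.sqrt(sum);
        }

        public static void main(String[] args) {
            System.out.println(minDistance(new float[] {0f, 0f}, new float[] {1f, 1f},
                    new float[] {4f, 5f})); // 5.0
        }
    }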
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java
deleted file mode 100644
index fd77c8342..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kdtree/KDTree.java
+++ /dev/null
@@ -1,370 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.kdtree;
-
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.reduce.bool.Any;
-import org.nd4j.linalg.api.ops.impl.reduce3.EuclideanDistance;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.common.primitives.Pair;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.Comparator;
-import java.util.List;
-
-public class KDTree implements Serializable {
-
- private KDNode root;
- private int dims = 100;
- public final static int GREATER = 1;
- public final static int LESS = 0;
- private int size = 0;
- private HyperRect rect;
-
- public KDTree(int dims) {
- this.dims = dims;
- }
-
- /**
- * Insert a point into the tree
- * @param point the point to insert
- */
- public void insert(INDArray point) {
- if (!point.isVector() || point.length() != dims)
- throw new IllegalArgumentException("Point must be a vector of length " + dims);
-
- if (root == null) {
- root = new KDNode(point);
- rect = new HyperRect(/*HyperRect.point(point)*/ point.toFloatVector());
- } else {
- int disc = 0;
- KDNode node = root;
- KDNode insert = new KDNode(point);
- int successor;
- while (true) {
- //exactly equal
- INDArray pt = node.getPoint();
- INDArray countEq = Nd4j.getExecutioner().execAndReturn(new Any(pt.neq(point))).z();
- if (countEq.getInt(0) == 0) {
- return;
- } else {
- successor = successor(node, point, disc);
- KDNode child;
- if (successor < 1)
- child = node.getLeft();
- else
- child = node.getRight();
- if (child == null)
- break;
- disc = (disc + 1) % dims;
- node = child;
- }
- }
-
- if (successor < 1)
- node.setLeft(insert);
-
- else
- node.setRight(insert);
-
- rect.enlargeTo(point);
- insert.setParent(node);
- }
- size++;
-
- }
-
-
- public INDArray delete(INDArray point) {
- KDNode node = root;
- int _disc = 0;
- while (node != null) {
- if (node.point == point)
- break;
- int successor = successor(node, point, _disc);
- if (successor < 1)
- node = node.getLeft();
- else
- node = node.getRight();
- _disc = (_disc + 1) % dims;
- }
-
- if (node != null) {
- if (node == root) {
- root = delete(root, _disc);
- } else
- node = delete(node, _disc);
- size--;
- if (size == 1) {
- rect = new HyperRect(HyperRect.point(point));
- } else if (size == 0)
- rect = null;
-
- }
- return node == null ? null : node.getPoint();
- }
-
- // Share this data for recursive calls of "knn"
- private float currentDistance;
- private INDArray currentPoint;
- private INDArray minDistance = Nd4j.scalar(0.f);
-
-
- public List<Pair<Float, INDArray>> knn(INDArray point, float distance) {
- List<Pair<Float, INDArray>> best = new ArrayList<>();
- currentDistance = distance;
- currentPoint = point;
- knn(root, rect, best, 0);
- Collections.sort(best, new Comparator<Pair<Float, INDArray>>() {
- @Override
- public int compare(Pair<Float, INDArray> o1, Pair<Float, INDArray> o2) {
- return Float.compare(o1.getKey(), o2.getKey());
- }
- });
-
- return best;
- }
-
-
- private void knn(KDNode node, HyperRect rect, List<Pair<Float, INDArray>> best, int _disc) {
- if (node == null || rect == null || rect.minDistance(currentPoint, minDistance) > currentDistance)
- return;
- int _discNext = (_disc + 1) % dims;
- float distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(currentPoint,node.point, minDistance)).getFinalResult()
- .floatValue();
-
- if (distance <= currentDistance) {
- best.add(Pair.of(distance, node.getPoint()));
- }
-
- HyperRect lower = rect.getLower(node.point, _disc);
- HyperRect upper = rect.getUpper(node.point, _disc);
- knn(node.getLeft(), lower, best, _discNext);
- knn(node.getRight(), upper, best, _discNext);
- }
-
- /**
- * Query for the nearest neighbor. Returns the distance and point
- * @param point the point to query for
- * @return the distance to, and coordinates of, the nearest neighbor
- */
- public Pair<Double, INDArray> nn(INDArray point) {
- return nn(root, point, rect, Double.POSITIVE_INFINITY, null, 0);
- }
-
-
- private Pair<Double, INDArray> nn(KDNode node, INDArray point, HyperRect rect, double dist, INDArray best,
- int _disc) {
- if (node == null || rect.minDistance(point, minDistance) > dist)
- return Pair.of(Double.POSITIVE_INFINITY, null);
-
- int _discNext = (_disc + 1) % dims;
- // distance from the query to this node's point (not to the origin)
- double dist2 = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(point, node.getPoint())).getFinalResult().doubleValue();
- if (dist2 < dist) {
- best = node.getPoint();
- dist = dist2;
- }
-
- HyperRect lower = rect.getLower(node.point, _disc);
- HyperRect upper = rect.getUpper(node.point, _disc);
-
- if (point.getDouble(_disc) < node.point.getDouble(_disc)) {
- Pair<Double, INDArray> left = nn(node.getLeft(), point, lower, dist, best, _discNext);
- Pair<Double, INDArray> right = nn(node.getRight(), point, upper, dist, best, _discNext);
- if (left.getKey() < dist)
- return left;
- else if (right.getKey() < dist)
- return right;
-
- } else {
- Pair<Double, INDArray> left = nn(node.getRight(), point, upper, dist, best, _discNext);
- Pair<Double, INDArray> right = nn(node.getLeft(), point, lower, dist, best, _discNext);
- if (left.getKey() < dist)
- return left;
- else if (right.getKey() < dist)
- return right;
- }
-
- return Pair.of(dist, best);
-
- }
-
- private KDNode delete(KDNode delete, int _disc) {
- if (delete.getLeft() != null && delete.getRight() != null) {
- if (delete.getParent() != null) {
- if (delete.getParent().getLeft() == delete)
- delete.getParent().setLeft(null);
- else
- delete.getParent().setRight(null);
-
- }
- return null;
- }
-
- int disc = _disc;
- _disc = (_disc + 1) % dims;
- Pair<KDNode, Integer> qd = null;
- if (delete.getRight() != null) {
- qd = min(delete.getRight(), disc, _disc);
- } else if (delete.getLeft() != null)
- qd = max(delete.getLeft(), disc, _disc);
- if (qd == null) {// is leaf
- return null;
- }
- delete.point = qd.getKey().point;
- KDNode qFather = qd.getKey().getParent();
- if (qFather.getLeft() == qd.getKey()) {
- qFather.setLeft(delete(qd.getKey(), disc));
- } else if (qFather.getRight() == qd.getKey()) {
- qFather.setRight(delete(qd.getKey(), disc));
-
- }
-
- return delete;
-
-
- }
-
-
- private Pair<KDNode, Integer> max(KDNode node, int disc, int _disc) {
- int discNext = (_disc + 1) % dims;
- if (_disc == disc) {
- KDNode child = node.getLeft();
- if (child != null) {
- return max(child, disc, discNext);
- }
- } else if (node.getLeft() != null || node.getRight() != null) {
- Pair<KDNode, Integer> left = null, right = null;
- if (node.getLeft() != null)
- left = max(node.getLeft(), disc, discNext);
- if (node.getRight() != null)
- right = max(node.getRight(), disc, discNext);
- if (left != null && right != null) {
- double pointLeft = left.getKey().getPoint().getDouble(disc);
- double pointRight = right.getKey().getPoint().getDouble(disc);
- if (pointLeft > pointRight)
- return left;
- else
- return right;
- } else if (left != null)
- return left;
- else
- return right;
- }
-
- return Pair.of(node, _disc);
- }
-
-
-
- private Pair<KDNode, Integer> min(KDNode node, int disc, int _disc) {
- int discNext = (_disc + 1) % dims;
- if (_disc == disc) {
- KDNode child = node.getLeft();
- if (child != null) {
- return min(child, disc, discNext);
- }
- } else if (node.getLeft() != null || node.getRight() != null) {
- Pair<KDNode, Integer> left = null, right = null;
- if (node.getLeft() != null)
- left = min(node.getLeft(), disc, discNext);
- if (node.getRight() != null)
- right = min(node.getRight(), disc, discNext);
- if (left != null && right != null) {
- double pointLeft = left.getKey().getPoint().getDouble(disc);
- double pointRight = right.getKey().getPoint().getDouble(disc);
- if (pointLeft < pointRight)
- return left;
- else
- return right;
- } else if (left != null)
- return left;
- else
- return right;
- }
-
- return Pair.of(node, _disc);
- }
-
- /**
- * The number of elements in the tree
- * @return the number of elements in the tree
- */
- public int size() {
- return size;
- }
-
- private int successor(KDNode node, INDArray point, int disc) {
- for (int i = disc; i < dims; i++) {
- double pointI = point.getDouble(i);
- double nodePointI = node.getPoint().getDouble(i);
- if (pointI < nodePointI)
- return LESS;
- else if (pointI > nodePointI)
- return GREATER;
-
- }
-
- throw new IllegalStateException("Point is equal!");
- }
-
-
- private static class KDNode {
- private INDArray point;
- private KDNode left, right, parent;
-
- public KDNode(INDArray point) {
- this.point = point;
- }
-
- public INDArray getPoint() {
- return point;
- }
-
- public KDNode getLeft() {
- return left;
- }
-
- public void setLeft(KDNode left) {
- this.left = left;
- }
-
- public KDNode getRight() {
- return right;
- }
-
- public void setRight(KDNode right) {
- this.right = right;
- }
-
- public KDNode getParent() {
- return parent;
- }
-
- public void setParent(KDNode parent) {
- this.parent = parent;
- }
- }
-
-
-}
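A minimal usage sketch for the tree, based only on the signatures above (coordinates are illustrative):

    import java.util.List;
    import org.deeplearning4j.clustering.kdtree.KDTree;
    import org.nd4j.common.primitives.Pair;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class KDTreeSketch {
        public static void main(String[] args) {
            KDTree tree = new KDTree(2);
            tree.insert(Nd4j.createFromArray(new float[] {0f, 0f}));
            tree.insert(Nd4j.createFromArray(new float[] {1f, 1f}));
            tree.insert(Nd4j.createFromArray(new float[] {5f, 5f}));

            INDArray query = Nd4j.createFromArray(new float[] {0.9f, 0.9f});
            Pair<Double, INDArray> nearest = tree.nn(query);         // (1, 1) and its distance
            List<Pair<Float, INDArray>> near = tree.knn(query, 2f);  // all points within 2, closest first
            System.out.println(nearest.getKey() + " / " + near.size());
        }
    }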
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kmeans/KMeansClustering.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kmeans/KMeansClustering.java
deleted file mode 100755
index 00b5bb3e9..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/kmeans/KMeansClustering.java
+++ /dev/null
@@ -1,109 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.kmeans;
-
-import org.deeplearning4j.clustering.algorithm.BaseClusteringAlgorithm;
-import org.deeplearning4j.clustering.algorithm.Distance;
-import org.deeplearning4j.clustering.strategy.ClusteringStrategy;
-import org.deeplearning4j.clustering.strategy.FixedClusterCountStrategy;
-
-
-public class KMeansClustering extends BaseClusteringAlgorithm {
-
- private static final long serialVersionUID = 8476951388145944776L;
- private static final double VARIATION_TOLERANCE= 1e-4;
-
-
- /**
- * @param clusteringStrategy the strategy driving the clustering iterations
- * @param useKMeansPlusPlus whether to use k-means++ initialization of the centers
- */
- protected KMeansClustering(ClusteringStrategy clusteringStrategy, boolean useKMeansPlusPlus) {
- super(clusteringStrategy, useKMeansPlusPlus);
- }
-
- /**
- * Setup a kmeans instance
- * @param clusterCount the number of clusters
- * @param maxIterationCount the max number of iterations to run kmeans
- * @param distanceFunction the distance function to use for grouping
- * @param inverse whether the distance function is an inverse (similarity) measure
- * @param useKMeansPlusPlus whether to use k-means++ initialization
- * @return a new KMeansClustering instance
- */
- public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction,
- boolean inverse, boolean useKMeansPlusPlus) {
- ClusteringStrategy clusteringStrategy =
- FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse);
- clusteringStrategy.endWhenIterationCountEquals(maxIterationCount);
- return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
- }
-
- /**
- * @param clusterCount the number of clusters
- * @param minDistributionVariationRate the variation rate below which the clustering is considered converged
- * @param distanceFunction the distance function to use for grouping
- * @param inverse whether the distance function is an inverse (similarity) measure
- * @param allowEmptyClusters whether empty clusters are allowed
- * @param useKMeansPlusPlus whether to use k-means++ initialization
- * @return a new KMeansClustering instance
- */
- public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction,
- boolean inverse, boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
- ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, inverse)
- .endWhenDistributionVariationRateLessThan(minDistributionVariationRate);
- return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
- }
-
-
- /**
- * Setup a kmeans instance
- * @param clusterCount the number of clusters
- * @param maxIterationCount the max number of iterations to run kmeans
- * @param distanceFunction the distance function to use for grouping
- * @param useKMeansPlusPlus whether to use k-means++ initialization
- * @return a new KMeansClustering instance
- */
- public static KMeansClustering setup(int clusterCount, int maxIterationCount, Distance distanceFunction, boolean useKMeansPlusPlus) {
- return setup(clusterCount, maxIterationCount, distanceFunction, false, useKMeansPlusPlus);
- }
-
- /**
- * @param clusterCount the number of clusters
- * @param minDistributionVariationRate the variation rate below which the clustering is considered converged
- * @param distanceFunction the distance function to use for grouping
- * @param allowEmptyClusters whether empty clusters are allowed
- * @param useKMeansPlusPlus whether to use k-means++ initialization
- * @return a new KMeansClustering instance
- */
- public static KMeansClustering setup(int clusterCount, double minDistributionVariationRate, Distance distanceFunction,
- boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
- ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false);
- clusteringStrategy.endWhenDistributionVariationRateLessThan(minDistributionVariationRate);
- return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
- }
-
- public static KMeansClustering setup(int clusterCount, Distance distanceFunction,
- boolean allowEmptyClusters, boolean useKMeansPlusPlus) {
- ClusteringStrategy clusteringStrategy = FixedClusterCountStrategy.setup(clusterCount, distanceFunction, false);
- clusteringStrategy.endWhenDistributionVariationRateLessThan(VARIATION_TOLERANCE);
- return new KMeansClustering(clusteringStrategy, useKMeansPlusPlus);
- }
-
-}
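A minimal usage sketch, assuming the applyTo(List<Point>) entry point that BaseClusteringAlgorithm (not part of this diff) exposes:

    import java.util.List;
    import org.deeplearning4j.clustering.algorithm.Distance;
    import org.deeplearning4j.clustering.cluster.ClusterSet;
    import org.deeplearning4j.clustering.cluster.Point;
    import org.deeplearning4j.clustering.kmeans.KMeansClustering;
    import org.nd4j.linalg.factory.Nd4j;

    public class KMeansSketch {
        public static void main(String[] args) {
            // 3 clusters, at most 100 iterations, Euclidean distance, k-means++ seeding
            KMeansClustering kmeans = KMeansClustering.setup(3, 100, Distance.EUCLIDEAN, true);
            List<Point> points = Point.toPoints(Nd4j.rand(50, 2));
            ClusterSet clusterSet = kmeans.applyTo(points); // assumed inherited entry point
            System.out.println(clusterSet.getClusterCount());
        }
    }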
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/lsh/LSH.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/lsh/LSH.java
deleted file mode 100644
index b9fbffa7a..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/lsh/LSH.java
+++ /dev/null
@@ -1,88 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.lsh;
-
-import org.nd4j.linalg.api.ndarray.INDArray;
-
-public interface LSH {
-
- /**
- * Returns an instance of the distance measure associated to the LSH family of this implementation.
- * Beware, hashing families and their amplification constructs are distance-specific.
- */
- String getDistanceMeasure();
-
- /**
- * Returns the size of a hash compared against in one hashing bucket, corresponding to an AND construction
- *
- * denoting hashLength by h,
- * amplifies a (d1, d2, p1, p2) hash family into a
- * (d1, d2, p1^h, p2^h)-sensitive one (match probability is decreasing with h)
- *
- * @return the length of the hash in the AND construction used by this index
- */
- int getHashLength();
-
- /**
- *
- * denoting numTables by n,
- * amplifies a (d1, d2, p1, p2) hash family into a
- * (d1, d2, 1-(1-p1)^n, 1-(1-p2)^n)-sensitive one (match probability is increasing with n)
- *
- * @return the # of hash tables in the OR construction used by this index
- */
- int getNumTables();
-
- /**
- * @return The dimension of the index vectors and queries
- */
- int getInDimension();
-
- /**
- * Populates the index with data vectors.
- * @param data the vectors to index
- */
- void makeIndex(INDArray data);
-
- /**
- * Returns the set of all vectors that could approximately be considered neighbors of the query,
- * without selection on the basis of distance or number of neighbors.
- * @param query a vector to find neighbors for
- * @return its approximate neighbors, unfiltered
- */
- INDArray bucket(INDArray query);
-
- /**
- * Returns the approximate neighbors within a distance bound.
- * @param query a vector to find neighbors for
- * @param maxRange the maximum distance between results and the query
- * @return approximate neighbors within the distance bounds
- */
- INDArray search(INDArray query, double maxRange);
-
- /**
- * Returns the approximate neighbors within a k-closest bound
- * @param query a vector to find neighbors for
- * @param k the maximum number of closest neighbors to return
- * @return at most k neighbors of the query, ordered by increasing distance
- */
- INDArray search(INDArray query, int k);
-}
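The two amplification formulas above compose in the standard way: within one table all h bits must match (AND, p -> p^h), and a point is a candidate if any of the n tables matches (OR, p -> 1 - (1 - p)^n). A numeric sketch with made-up base probabilities:

    public class AmplificationSketch {
        public static void main(String[] args) {
            double p1 = 0.9;   // per-bit match probability for a near pair
            int h = 8;         // hashLength: AND construction
            int n = 12;        // numTables: OR construction
            double pTable = Math.pow(p1, h);               // ~0.430: one table matches
            double pIndex = 1 - Math.pow(1 - pTable, n);   // ~0.999: at least one of n tables
            System.out.println(pTable + " -> " + pIndex);
        }
    }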
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSH.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSH.java
deleted file mode 100644
index 7b9873d73..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/lsh/RandomProjectionLSH.java
+++ /dev/null
@@ -1,227 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.lsh;
-
-import lombok.Getter;
-import lombok.val;
-import org.nd4j.common.base.Preconditions;
-import org.nd4j.linalg.api.buffer.DataType;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo;
-import org.nd4j.linalg.api.ops.impl.transforms.same.Sign;
-import org.nd4j.linalg.api.ops.random.impl.GaussianDistribution;
-import org.nd4j.linalg.api.rng.Random;
-import org.nd4j.linalg.exception.ND4JIllegalStateException;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.linalg.indexing.BooleanIndexing;
-import org.nd4j.linalg.indexing.conditions.Conditions;
-import org.nd4j.linalg.ops.transforms.Transforms;
-
-import java.util.Arrays;
-
-
-public class RandomProjectionLSH implements LSH {
-
- @Override
- public String getDistanceMeasure(){
- return "cosinedistance";
- }
-
- @Getter private int hashLength;
-
- @Getter private int numTables;
-
- @Getter private int inDimension;
-
-
- @Getter private double radius;
-
- INDArray randomProjection;
-
- INDArray index;
-
- INDArray indexData;
-
-
- private INDArray gaussianRandomMatrix(int[] shape, Random rng){
- INDArray res = Nd4j.create(shape);
-
- GaussianDistribution op1 = new GaussianDistribution(res, 0.0, 1.0 / Math.sqrt(shape[0]));
-
- Nd4j.getExecutioner().exec(op1, rng);
- return res;
- }
-
- public RandomProjectionLSH(int hashLength, int numTables, int inDimension, double radius){
- this(hashLength, numTables, inDimension, radius, Nd4j.getRandom());
- }
-
- /**
- * Creates a locality-sensitive hashing index for the cosine distance,
- * a (d1, d2, (180 − d1)/180, (180 − d2)/180)-sensitive hash family before amplification
- *
- * @param hashLength the length of the compared hash in an AND construction
- * @param numTables the entropy-equivalent of a nb of hash tables in an OR construction, implemented here with the multiple
- * probes of Panigrahy (op. cit.)
- * @param inDimension the dimensionality of the points being indexed
- * @param radius the radius around the query within which entropy probes are generated, playing the role of
- * multiple physical hash tables in an OR construction
- * @param rng a Random object to draw samples from
- */
- public RandomProjectionLSH(int hashLength, int numTables, int inDimension, double radius, Random rng){
- this.hashLength = hashLength;
- this.numTables = numTables;
- this.inDimension = inDimension;
- this.radius = radius;
- randomProjection = gaussianRandomMatrix(new int[]{inDimension, hashLength}, rng);
- }
-
- /**
- * This picks uniformly distributed random points on the surface of a unit sphere, using the method of:
- *
- * An efficient method for generating uniformly distributed points on the surface of an n-dimensional sphere
- * JS Hicks, RF Wheeling - Communications of the ACM, 1959
- * @param data a query to generate multiple probes for
- * @return {@code numTables} entropy probes derived from the query
- */
- public INDArray entropy(INDArray data){
-
- INDArray data2 =
- Nd4j.getExecutioner().exec(new GaussianDistribution(Nd4j.create(numTables, inDimension), radius));
-
- INDArray norms = Nd4j.norm2(data2.dup(), -1);
-
- Preconditions.checkState(norms.rank() == 1 && norms.size(0) == numTables, "Expected norm2 to have shape [%s], is %ndShape", norms.size(0), norms);
-
- data2.diviColumnVector(norms);
- data2.addiRowVector(data);
- return data2;
- }
-
- /**
- * Returns hash values for a particular query
- * @param data a query vector
- * @return its hashed value
- */
- public INDArray hash(INDArray data) {
- if (data.shape()[1] != inDimension){
- throw new ND4JIllegalStateException(
- String.format("Invalid shape: Requested INDArray shape %s, this table expects dimension %d",
- Arrays.toString(data.shape()), inDimension));
- }
- INDArray projected = data.mmul(randomProjection);
- INDArray res = Nd4j.getExecutioner().exec(new Sign(projected));
- return res;
- }
-
- /**
- * Populates the index. Beware, not incremental, any further call replaces the index instead of adding to it.
- * @param data the vectors to index
- */
- @Override
- public void makeIndex(INDArray data) {
- index = hash(data);
- indexData = data;
- }
-
- // data elements in the same bucket as the query, without entropy
- INDArray rawBucketOf(INDArray query){
- INDArray pattern = hash(query);
-
- INDArray res = Nd4j.zeros(DataType.BOOL, index.shape());
- Nd4j.getExecutioner().exec(new BroadcastEqualTo(index, pattern, res, -1));
- return res.castTo(Nd4j.defaultFloatingPointType()).min(-1);
- }
-
- @Override
- public INDArray bucket(INDArray query) {
- INDArray queryRes = rawBucketOf(query);
-
- if(numTables > 1) {
- INDArray entropyQueries = entropy(query);
-
- // loop, addi + conditionalreplace -> poor man's OR function
- for (int i = 0; i < numTables; i++) {
- INDArray row = entropyQueries.getRow(i, true);
- queryRes.addi(rawBucketOf(row));
- }
- BooleanIndexing.replaceWhere(queryRes, 1.0, Conditions.greaterThan(0.0));
- }
-
- return queryRes;
- }
-
- // data elements in the same entropy bucket as the query,
- INDArray bucketData(INDArray query){
- INDArray mask = bucket(query);
- int nRes = mask.sum(0).getInt(0);
- INDArray res = Nd4j.create(new int[] {nRes, inDimension});
- int j = 0;
- for (int i = 0; i < nRes; i++){
- while (mask.getInt(j) == 0 && j < mask.length() - 1) {
- j += 1;
- }
- if (mask.getInt(j) == 1) res.putRow(i, indexData.getRow(j));
- j += 1;
- }
- return res;
- }
-
- @Override
- public INDArray search(INDArray query, double maxRange) {
- if (maxRange < 0)
- throw new IllegalArgumentException("ANN search should have a positive maximum search radius");
-
- INDArray bucketData = bucketData(query);
- INDArray distances = Transforms.allCosineDistances(bucketData, query, -1);
- INDArray[] idxs = Nd4j.sortWithIndices(distances, -1, true);
-
- INDArray shuffleIndexes = idxs[0];
- INDArray sortedDistances = idxs[1];
- int accepted = 0;
- while (accepted < sortedDistances.length() && sortedDistances.getDouble(accepted) <= maxRange) accepted += 1;
-
- INDArray res = Nd4j.create(new int[] {accepted, inDimension});
- for(int i = 0; i < accepted; i++){
- res.putRow(i, bucketData.getRow(shuffleIndexes.getInt(i)));
- }
- return res;
- }
-
- @Override
- public INDArray search(INDArray query, int k) {
- if (k < 1)
- throw new IllegalArgumentException("An ANN search for k neighbors should at least seek one neighbor");
-
- INDArray bucketData = bucketData(query);
- INDArray distances = Transforms.allCosineDistances(bucketData, query, -1);
- INDArray[] idxs = Nd4j.sortWithIndices(distances, -1, true);
-
- INDArray shuffleIndexes = idxs[0];
- INDArray sortedDistances = idxs[1];
- val accepted = Math.min(k, sortedDistances.shape()[1]);
-
- INDArray res = Nd4j.create(accepted, inDimension);
- for(int i = 0; i < accepted; i++){
- res.putRow(i, bucketData.getRow(shuffleIndexes.getInt(i)));
- }
- return res;
- }
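-
- // End-to-end usage sketch (hypothetical values, not from the original file):
- //
- // RandomProjectionLSH lsh = new RandomProjectionLSH(16, 4, 100, 0.1, Nd4j.getRandom());
- // INDArray data = Nd4j.randn(1000, 100);
- // lsh.makeIndex(data); // hash every row once
- // INDArray neighbors = lsh.search(data.getRow(0, true), 10); // up to 10 rows, nearest first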
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/optimisation/ClusteringOptimization.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/optimisation/ClusteringOptimization.java
deleted file mode 100644
index b65571de3..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/optimisation/ClusteringOptimization.java
+++ /dev/null
@@ -1,38 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.optimisation;
-
-import lombok.AccessLevel;
-import lombok.AllArgsConstructor;
-import lombok.Data;
-import lombok.NoArgsConstructor;
-
-import java.io.Serializable;
-
-@Data
-@NoArgsConstructor(access = AccessLevel.PROTECTED)
-@AllArgsConstructor
-public class ClusteringOptimization implements Serializable {
-
- private ClusteringOptimizationType type;
- private double value;
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/optimisation/ClusteringOptimizationType.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/optimisation/ClusteringOptimizationType.java
deleted file mode 100644
index a2220010e..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/optimisation/ClusteringOptimizationType.java
+++ /dev/null
@@ -1,28 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.optimisation;
-
-/**
- * Objectives that a clustering optimization can minimize
- */
-public enum ClusteringOptimizationType {
- MINIMIZE_AVERAGE_POINT_TO_CENTER_DISTANCE,
- MINIMIZE_MAXIMUM_POINT_TO_CENTER_DISTANCE,
- MINIMIZE_AVERAGE_POINT_TO_POINT_DISTANCE,
- MINIMIZE_MAXIMUM_POINT_TO_POINT_DISTANCE,
- MINIMIZE_PER_CLUSTER_POINT_COUNT
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/Cell.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/Cell.java
deleted file mode 100644
index cb82b6f87..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/Cell.java
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.quadtree;
-
-import org.nd4j.linalg.api.ndarray.INDArray;
-
-import java.io.Serializable;
-
-public class Cell implements Serializable {
- private double x, y, hw, hh;
-
- public Cell(double x, double y, double hw, double hh) {
- this.x = x;
- this.y = y;
- this.hw = hw;
- this.hh = hh;
- }
-
- /**
- * Whether the given point is contained
- * within this cell
- * @param point the point to check
- * @return true if the point is contained, false otherwise
- */
- public boolean containsPoint(INDArray point) {
- double first = point.getDouble(0), second = point.getDouble(1);
- return x - hw <= first && x + hw >= first && y - hh <= second && y + hh >= second;
- }
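-
- // Worked example of the containment test above: a cell is an axis-aligned box
- // given by its center (x, y) and half-extents (hw, hh); new Cell(0, 0, 1, 2)
- // spans [-1, 1] x [-2, 2], so it contains (1.0, -2.0) on its border but not (1.5, 0.0).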
-
- @Override
- public boolean equals(Object o) {
- if (this == o)
- return true;
- if (!(o instanceof Cell))
- return false;
-
- Cell cell = (Cell) o;
-
- if (Double.compare(cell.hh, hh) != 0)
- return false;
- if (Double.compare(cell.hw, hw) != 0)
- return false;
- if (Double.compare(cell.x, x) != 0)
- return false;
- return Double.compare(cell.y, y) == 0;
-
- }
-
- @Override
- public int hashCode() {
- int result;
- long temp;
- temp = Double.doubleToLongBits(x);
- result = (int) (temp ^ (temp >>> 32));
- temp = Double.doubleToLongBits(y);
- result = 31 * result + (int) (temp ^ (temp >>> 32));
- temp = Double.doubleToLongBits(hw);
- result = 31 * result + (int) (temp ^ (temp >>> 32));
- temp = Double.doubleToLongBits(hh);
- result = 31 * result + (int) (temp ^ (temp >>> 32));
- return result;
- }
-
- public double getX() {
- return x;
- }
-
- public void setX(double x) {
- this.x = x;
- }
-
- public double getY() {
- return y;
- }
-
- public void setY(double y) {
- this.y = y;
- }
-
- public double getHw() {
- return hw;
- }
-
- public void setHw(double hw) {
- this.hw = hw;
- }
-
- public double getHh() {
- return hh;
- }
-
- public void setHh(double hh) {
- this.hh = hh;
- }
-
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java
deleted file mode 100644
index 20d216b44..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/quadtree/QuadTree.java
+++ /dev/null
@@ -1,383 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.quadtree;
-
-import org.nd4j.shade.guava.util.concurrent.AtomicDouble;
-import org.apache.commons.math3.util.FastMath;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
-
-import java.io.Serializable;
-
-import static java.lang.Math.max;
-
-public class QuadTree implements Serializable {
- private QuadTree parent, northWest, northEast, southWest, southEast;
- private boolean isLeaf = true;
- private int size, cumSize;
- private Cell boundary;
- static final int QT_NO_DIMS = 2;
- static final int QT_NODE_CAPACITY = 1;
- private INDArray buf = Nd4j.create(QT_NO_DIMS);
- private INDArray data, centerOfMass = Nd4j.create(QT_NO_DIMS);
- private int[] index = new int[QT_NODE_CAPACITY];
-
-
- /**
- * Build a quad tree over a matrix of 2D points, inserting every row
- * @param data the n x 2 matrix of points to index
- */
- public QuadTree(INDArray data) {
- INDArray meanY = data.mean(0);
- INDArray minY = data.min(0);
- INDArray maxY = data.max(0);
- init(data, meanY.getDouble(0), meanY.getDouble(1),
- max(maxY.getDouble(0) - meanY.getDouble(0), meanY.getDouble(0) - minY.getDouble(0))
- + Nd4j.EPS_THRESHOLD,
- max(maxY.getDouble(1) - meanY.getDouble(1), meanY.getDouble(1) - minY.getDouble(1))
- + Nd4j.EPS_THRESHOLD);
- fill();
- }
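-
- // Usage sketch (hypothetical values): the root cell is centered on the column
- // means and sized to cover the extremes of the data.
- //
- // INDArray points = Nd4j.randn(100, 2);
- // QuadTree tree = new QuadTree(points); // inserts all 100 rows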
-
- public QuadTree(QuadTree parent, INDArray data, Cell boundary) {
- this.parent = parent;
- this.boundary = boundary;
- this.data = data;
-
- }
-
- public QuadTree(Cell boundary) {
- this.boundary = boundary;
- }
-
- private void init(INDArray data, double x, double y, double hw, double hh) {
- boundary = new Cell(x, y, hw, hh);
- this.data = data;
- }
-
- private void fill() {
- for (int i = 0; i < data.rows(); i++)
- insert(i);
- }
-
-
-
- /**
- * Returns the child quadrant that should contain the given coordinates
- *
- * @param coordinates the 2D point to locate
- * @return the child quadrant whose cell covers the point
- */
- protected QuadTree findIndex(INDArray coordinates) {
-
- // Compute the sector for the coordinates
- boolean left = (coordinates.getDouble(0) <= (boundary.getX() + boundary.getHw() / 2));
- boolean top = (coordinates.getDouble(1) <= (boundary.getY() + boundary.getHh() / 2));
-
- // top left
- QuadTree index = getNorthWest();
- if (left) {
- // left side
- if (!top) {
- // bottom left
- index = getSouthWest();
- }
- } else {
- // right side
- if (top) {
- // top right
- index = getNorthEast();
- } else {
- // bottom right
- index = getSouthEast();
-
- }
- }
-
- return index;
- }
-
-
- /**
- * Insert an index of the data in to the tree
- * @param newIndex the index to insert in to the tree
- * @return whether the index was inserted or not
- */
- public boolean insert(int newIndex) {
- // Ignore objects which do not belong in this quad tree
- INDArray point = data.slice(newIndex);
- if (!boundary.containsPoint(point))
- return false;
-
- cumSize++;
- double mult1 = (double) (cumSize - 1) / (double) cumSize;
- double mult2 = 1.0 / (double) cumSize;
-
- centerOfMass.muli(mult1);
- centerOfMass.addi(point.mul(mult2));
-
- // If there is space in this quad tree and it is a leaf, add the object here
- if (isLeaf() && size < QT_NODE_CAPACITY) {
- index[size] = newIndex;
- size++;
- return true;
- }
-
- //duplicate point
- if (size > 0) {
- for (int i = 0; i < size; i++) {
- INDArray compPoint = data.slice(index[i]);
- if (point.getDouble(0) == compPoint.getDouble(0) && point.getDouble(1) == compPoint.getDouble(1))
- return true;
- }
- }
-
-
-
- // If this Node has already been subdivided just add the elements to the
- // appropriate cell
- if (!isLeaf()) {
- QuadTree index = findIndex(point);
- index.insert(newIndex);
- return true;
- }
-
- if (isLeaf())
- subDivide();
-
- boolean ret = insertIntoOneOf(newIndex);
-
-
-
- return ret;
- }
-
- private boolean insertIntoOneOf(int index) {
- boolean success = false;
- success = northWest.insert(index);
- if (!success)
- success = northEast.insert(index);
- if (!success)
- success = southWest.insert(index);
- if (!success)
- success = southEast.insert(index);
- return success;
- }
-
-
- /**
- * Returns whether the tree is consistent or not
- * @return whether the tree is consistent or not
- */
- public boolean isCorrect() {
-
- for (int n = 0; n < size; n++) {
- INDArray point = data.slice(index[n]);
- if (!boundary.containsPoint(point))
- return false;
- }
-
- return isLeaf() || northWest.isCorrect() && northEast.isCorrect() && southWest.isCorrect()
- && southEast.isCorrect();
-
- }
-
-
-
- /**
- * Create four children
- * which fully divide this cell
- * into four quads of equal area
- */
- public void subDivide() {
- northWest = new QuadTree(this, data, new Cell(boundary.getX() - .5 * boundary.getHw(),
- boundary.getY() - .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh()));
- northEast = new QuadTree(this, data, new Cell(boundary.getX() + .5 * boundary.getHw(),
- boundary.getY() - .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh()));
- southWest = new QuadTree(this, data, new Cell(boundary.getX() - .5 * boundary.getHw(),
- boundary.getY() + .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh()));
- southEast = new QuadTree(this, data, new Cell(boundary.getX() + .5 * boundary.getHw(),
- boundary.getY() + .5 * boundary.getHh(), .5 * boundary.getHw(), .5 * boundary.getHh()));
-
- }
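-
- // Worked example of the subdivision above: a cell centered at (0, 0) with
- // half-extents (2, 2) yields four children of half-extents (1, 1) centered at
- // (-1, -1), (1, -1), (-1, 1) and (1, 1); "north" here is the smaller y
- // coordinate, matching the quadrant test in findIndex.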
-
-
- /**
- * Compute non-edge forces using Barnes-Hut
- * @param pointIndex the index of the query point
- * @param theta the Barnes-Hut accuracy/speed tradeoff parameter
- * @param negativeForce vector accumulating the negative force on the point
- * @param sumQ accumulator for the normalization term
- */
- public void computeNonEdgeForces(int pointIndex, double theta, INDArray negativeForce, AtomicDouble sumQ) {
- // Make sure that we spend no time on empty nodes or self-interactions
- if (cumSize == 0 || (isLeaf() && size == 1 && index[0] == pointIndex))
- return;
-
-
- // Compute distance between point and center-of-mass
- buf.assign(data.slice(pointIndex)).subi(centerOfMass);
-
- double D = Nd4j.getBlasWrapper().dot(buf, buf);
-
- // Check whether we can use this node as a "summary"
- if (isLeaf || FastMath.max(boundary.getHh(), boundary.getHw()) / FastMath.sqrt(D) < theta) {
-
- // Compute and add t-SNE force between point and current node
- double Q = 1.0 / (1.0 + D);
- double mult = cumSize * Q;
- sumQ.addAndGet(mult);
- mult *= Q;
- negativeForce.addi(buf.mul(mult));
-
- } else {
-
- // Recursively apply Barnes-Hut to children
- northWest.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
- northEast.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
- southWest.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
- southEast.computeNonEdgeForces(pointIndex, theta, negativeForce, sumQ);
- }
- }
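-
- // Worked example of the Barnes-Hut criterion above: with theta = 0.5, a node
- // whose largest half-extent is 0.4 summarizes its whole subtree for any point
- // at distance sqrt(D) > 0.8; the summary adds cumSize * Q to sumQ and
- // cumSize * Q^2 * (point - centerOfMass) to the negative force, with Q = 1 / (1 + D).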
-
-
-
- /**
- * Compute edge (positive) forces from the sparse affinity matrix
- * @param rowP row pointers of the sparse matrix (CSR-style), as a vector
- * @param colP column indices of the non-zero entries
- * @param valP values of the non-zero entries
- * @param N the number of points
- * @param posF matrix accumulating the positive force on each point
- */
- public void computeEdgeForces(INDArray rowP, INDArray colP, INDArray valP, int N, INDArray posF) {
- if (!rowP.isVector())
- throw new IllegalArgumentException("RowP must be a vector");
-
- // Loop over all edges in the graph
- double D;
- for (int n = 0; n < N; n++) {
- for (int i = rowP.getInt(n); i < rowP.getInt(n + 1); i++) {
-
- // Compute pairwise distance and Q-value
- buf.assign(data.slice(n)).subi(data.slice(colP.getInt(i)));
-
- D = Nd4j.getBlasWrapper().dot(buf, buf);
- D = valP.getDouble(i) / D;
-
- // Sum positive force
- posF.slice(n).addi(buf.mul(D));
-
- }
- }
- }
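-
- // Illustrative note: rowP/colP/valP follow a CSR-style sparse layout. For
- // example rowP = [0, 2, 3], colP = [1, 2, 0], valP = [p01, p02, p10] means
- // point 0 has edges to points 1 and 2, and point 1 has an edge to point 0.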
-
-
- /**
- * The depth of the node
- * @return the depth of the node
- */
- public int depth() {
- if (isLeaf())
- return 1;
- return 1 + max(max(northWest.depth(), northEast.depth()), max(southWest.depth(), southEast.depth()));
- }
-
- public INDArray getCenterOfMass() {
- return centerOfMass;
- }
-
- public void setCenterOfMass(INDArray centerOfMass) {
- this.centerOfMass = centerOfMass;
- }
-
- public QuadTree getParent() {
- return parent;
- }
-
- public void setParent(QuadTree parent) {
- this.parent = parent;
- }
-
- public QuadTree getNorthWest() {
- return northWest;
- }
-
- public void setNorthWest(QuadTree northWest) {
- this.northWest = northWest;
- }
-
- public QuadTree getNorthEast() {
- return northEast;
- }
-
- public void setNorthEast(QuadTree northEast) {
- this.northEast = northEast;
- }
-
- public QuadTree getSouthWest() {
- return southWest;
- }
-
- public void setSouthWest(QuadTree southWest) {
- this.southWest = southWest;
- }
-
- public QuadTree getSouthEast() {
- return southEast;
- }
-
- public void setSouthEast(QuadTree southEast) {
- this.southEast = southEast;
- }
-
- public boolean isLeaf() {
- return isLeaf;
- }
-
- public void setLeaf(boolean isLeaf) {
- this.isLeaf = isLeaf;
- }
-
- public int getSize() {
- return size;
- }
-
- public void setSize(int size) {
- this.size = size;
- }
-
- public int getCumSize() {
- return cumSize;
- }
-
- public void setCumSize(int cumSize) {
- this.cumSize = cumSize;
- }
-
- public Cell getBoundary() {
- return boundary;
- }
-
- public void setBoundary(Cell boundary) {
- this.boundary = boundary;
- }
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPForest.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPForest.java
deleted file mode 100644
index f814025d5..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPForest.java
+++ /dev/null
@@ -1,104 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.randomprojection;
-
-import lombok.Data;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.common.primitives.Pair;
-
-import java.util.ArrayList;
-import java.util.List;
-
-/**
- * A forest of random projection trees for approximate nearest neighbor search
- */
-@Data
-public class RPForest {
-
- private int numTrees;
- private List<RPTree> trees;
- private INDArray data;
- private int maxSize = 1000;
- private String similarityFunction;
-
- /**
- * Create the rp forest with the specified number of trees
- * @param numTrees the number of trees in the forest
- * @param maxSize the max size of each tree
- * @param similarityFunction the distance function to use
- */
- public RPForest(int numTrees,int maxSize,String similarityFunction) {
- this.numTrees = numTrees;
- this.maxSize = maxSize;
- this.similarityFunction = similarityFunction;
- trees = new ArrayList<>(numTrees);
-
- }
-
-
- /**
- * Build the trees from the given dataset
- * @param x the input dataset (should be a 2d matrix)
- */
- public void fit(INDArray x) {
- this.data = x;
- for(int i = 0; i < numTrees; i++) {
- RPTree tree = new RPTree(data.columns(),maxSize,similarityFunction);
- tree.buildTree(x);
- trees.add(tree);
- }
- }
-
- /**
- * Get all candidates relative to a specific datapoint.
- * @param input the datapoint to collect candidates for
- * @return the candidate indices gathered across all trees
- */
- public INDArray getAllCandidates(INDArray input) {
- return RPUtils.getAllCandidates(input,trees,similarityFunction);
- }
-
- /**
- * Query results up to length n
- * nearest neighbors
- * @param toQuery the query item
- * @param n the number of nearest neighbors for the given data point
- * @return the indices for the nearest neighbors
- */
- public INDArray queryAll(INDArray toQuery,int n) {
- return RPUtils.queryAll(toQuery,data,trees,n,similarityFunction);
- }
-
-
- /**
- * Query all with the distances
- * sorted by index
- * @param query the query vector
- * @param numResults the number of results to return
- * @return a list of samples
- */
- public List<Pair<Double, Integer>> queryWithDistances(INDArray query, int numResults) {
- return RPUtils.queryAllWithDistances(query,this.data, trees,numResults,similarityFunction);
- }
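-
- // Usage sketch (hypothetical values, not from the original file):
- //
- // RPForest forest = new RPForest(10, 1000, "euclidean");
- // forest.fit(Nd4j.randn(10000, 64)); // builds 10 trees over the data
- // INDArray nn = forest.queryAll(Nd4j.randn(1, 64), 5); // indices of up to 5 nearest rows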
-
-
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPHyperPlanes.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPHyperPlanes.java
deleted file mode 100644
index 979013797..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPHyperPlanes.java
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.randomprojection;
-
-import lombok.Data;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.factory.Nd4j;
-
-@Data
-public class RPHyperPlanes {
- private int dim;
- private INDArray wholeHyperPlane;
-
- public RPHyperPlanes(int dim) {
- this.dim = dim;
- }
-
- public INDArray getHyperPlaneAt(int depth) {
- if(wholeHyperPlane.isVector())
- return wholeHyperPlane;
- return wholeHyperPlane.slice(depth);
- }
-
-
- /**
- * Add a new random element to the hyper plane.
- */
- public void addRandomHyperPlane() {
- INDArray newPlane = Nd4j.randn(new int[] {1,dim});
- newPlane.divi(newPlane.normmaxNumber());
- if(wholeHyperPlane == null)
- wholeHyperPlane = newPlane;
- else {
- wholeHyperPlane = Nd4j.concat(0,wholeHyperPlane,newPlane);
- }
- }
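-
- // Illustrative note: each call appends one random direction, so after k calls
- // wholeHyperPlane is a k x dim matrix; getHyperPlaneAt(depth) then selects the
- // splitting direction used at that depth of an RPTree.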
-
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPNode.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPNode.java
deleted file mode 100644
index 9a103469e..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPNode.java
+++ /dev/null
@@ -1,48 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.randomprojection;
-
-
-import lombok.Data;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.Future;
-
-@Data
-public class RPNode {
- private int depth;
- private RPNode left,right;
- private Future leftFuture,rightFuture;
- private List<Integer> indices;
- private double median;
- private RPTree tree;
-
-
- public RPNode(RPTree tree,int depth) {
- this.depth = depth;
- this.tree = tree;
- indices = new ArrayList<>();
- }
-
-
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPTree.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPTree.java
deleted file mode 100644
index 7fbca2b90..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPTree.java
+++ /dev/null
@@ -1,130 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.randomprojection;
-
-import lombok.Builder;
-import lombok.Data;
-import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
-import org.nd4j.linalg.api.memory.enums.*;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.common.primitives.Pair;
-
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-import java.util.concurrent.ExecutorService;
-
-@Data
-public class RPTree {
- private RPNode root;
- private RPHyperPlanes rpHyperPlanes;
- private int dim;
- //also known as leaf size
- private int maxSize;
- private INDArray X;
- private String similarityFunction = "euclidean";
- private WorkspaceConfiguration workspaceConfiguration;
- private ExecutorService searchExecutor;
- private int searchWorkers;
-
- /**
- *
- * @param dim the dimension of the vectors
- * @param maxSize the max size of the leaves
- * @param similarityFunction the distance function to use (e.g. "euclidean")
- *
- */
- @Builder
- public RPTree(int dim, int maxSize,String similarityFunction) {
- this.dim = dim;
- this.maxSize = maxSize;
- rpHyperPlanes = new RPHyperPlanes(dim);
- root = new RPNode(this,0);
- this.similarityFunction = similarityFunction;
- workspaceConfiguration = WorkspaceConfiguration.builder().cyclesBeforeInitialization(1)
- .policyAllocation(AllocationPolicy.STRICT).policyLearning(LearningPolicy.FIRST_LOOP)
- .policyMirroring(MirroringPolicy.FULL).policyReset(ResetPolicy.BLOCK_LEFT)
- .policySpill(SpillPolicy.REALLOCATE).build();
-
- }
-
- /**
- *
- * @param dim the dimension of the vectors
- * @param maxSize the max size of the leaves
- *
- */
- public RPTree(int dim, int maxSize) {
- this(dim,maxSize,"euclidean");
- }
-
- /**
- * Build the tree with the given input data
- * @param x the input dataset (a 2d matrix)
- */
- public void buildTree(INDArray x) {
- this.X = x;
- for(int i = 0; i < x.rows(); i++) {
- root.getIndices().add(i);
- }
-
-
-
- RPUtils.buildTree(this,root,rpHyperPlanes,
- x,maxSize,0,similarityFunction);
- }
-
-
-
- public void addNodeAtIndex(int idx,INDArray toAdd) {
- RPNode query = RPUtils.query(root,rpHyperPlanes,toAdd,similarityFunction);
- query.getIndices().add(idx);
- }
-
-
- public List<RPNode> getLeaves() {
- List<RPNode> nodes = new ArrayList<>();
- RPUtils.scanForLeaves(nodes,getRoot());
- return nodes;
- }
-
-
- /**
- * Query all with the distances
- * sorted by index
- * @param query the query vector
- * @param numResults the number of results to return
- * @return a list of samples
- */
- public List<Pair<Double, Integer>> queryWithDistances(INDArray query, int numResults) {
- return RPUtils.queryAllWithDistances(query,X,Arrays.asList(this),numResults,similarityFunction);
- }
-
- public INDArray query(INDArray query,int numResults) {
- return RPUtils.queryAll(query,X,Arrays.asList(this),numResults,similarityFunction);
- }
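-
- // Usage sketch (hypothetical values, not from the original file):
- //
- // RPTree tree = new RPTree(64, 1000); // euclidean by default
- // tree.buildTree(Nd4j.randn(10000, 64));
- // INDArray nearest = tree.query(Nd4j.randn(1, 64), 10); // candidate indices, nearest first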
-
- public List<Integer> getCandidates(INDArray target) {
- return RPUtils.getCandidates(target,Arrays.asList(this),similarityFunction);
- }
-
-
-}
diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java
deleted file mode 100644
index 0bd2574e7..000000000
--- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/randomprojection/RPUtils.java
+++ /dev/null
@@ -1,481 +0,0 @@
-/*
- * ******************************************************************************
- * *
- * *
- * * This program and the accompanying materials are made available under the
- * * terms of the Apache License, Version 2.0 which is available at
- * * https://www.apache.org/licenses/LICENSE-2.0.
- * *
- * * See the NOTICE file distributed with this work for additional
- * * information regarding copyright ownership.
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * * License for the specific language governing permissions and limitations
- * * under the License.
- * *
- * * SPDX-License-Identifier: Apache-2.0
- * *****************************************************************************
- */
-
-package org.deeplearning4j.clustering.randomprojection;
-
-import org.nd4j.shade.guava.primitives.Doubles;
-import lombok.val;
-import org.nd4j.autodiff.functions.DifferentialFunction;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.ReduceOp;
-import org.nd4j.linalg.api.ops.impl.reduce3.*;
-import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
-import org.nd4j.linalg.factory.Nd4j;
-import org.nd4j.common.primitives.Pair;
-
-import java.util.*;
-
-public class RPUtils {
-
-
- private static ThreadLocal