diff --git a/datavec/datavec-camel/pom.xml b/datavec/datavec-camel/pom.xml deleted file mode 100644 index 3390242bc..000000000 --- a/datavec/datavec-camel/pom.xml +++ /dev/null @@ -1,116 +0,0 @@ - - - - - - datavec-parent - org.datavec - 1.0.0-SNAPSHOT - - - 4.0.0 - - datavec-camel - - DataVec Camel Component - http://deeplearning4j.org - - - - - org.apache.camel - camel-csv - ${camel.version} - test - - - org.datavec - datavec-api - ${project.version} - - - org.apache.camel - camel-core - ${camel.version} - - - - - - - org.apache.camel - camel-test - ${camel.version} - test - - - - - install - - - - - maven-compiler-plugin - - 1.7 - 1.7 - - - - - maven-resources-plugin - 3.0.1 - - UTF-8 - - - - - - org.apache.camel - camel-package-maven-plugin - ${camel.version} - - - prepare - - prepare-components - - generate-resources - - - - - - - - - - test-nd4j-native - - - test-nd4j-cuda-10.2 - - - diff --git a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecComponent.java b/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecComponent.java deleted file mode 100644 index 2abb73558..000000000 --- a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecComponent.java +++ /dev/null @@ -1,45 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.camel.component; - -import org.apache.camel.CamelContext; -import org.apache.camel.Endpoint; -import org.apache.camel.impl.UriEndpointComponent; - -import java.util.Map; - -/** - * Represents the component that manages {@link DataVecEndpoint}. - */ -public class DataVecComponent extends UriEndpointComponent { - - public DataVecComponent() { - super(DataVecEndpoint.class); - } - - public DataVecComponent(CamelContext context) { - super(context, DataVecEndpoint.class); - } - - @Override - protected Endpoint createEndpoint(String uri, String remaining, Map parameters) throws Exception { - DataVecEndpoint endpoint = new DataVecEndpoint(uri, this); - setProperties(endpoint, parameters); - endpoint.setInputFormat(remaining); - return endpoint; - } -} diff --git a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecConsumer.java b/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecConsumer.java deleted file mode 100644 index e17d022ef..000000000 --- a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecConsumer.java +++ /dev/null @@ -1,93 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.camel.component; - - -import org.apache.camel.Exchange; -import org.apache.camel.Processor; -import org.apache.camel.impl.ScheduledPollConsumer; -import org.datavec.api.conf.Configuration; -import org.datavec.api.formats.input.InputFormat; -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.split.InputSplit; - -/** - * The DataVec consumer. - * @author Adam Gibson - */ -public class DataVecConsumer extends ScheduledPollConsumer { - private final DataVecEndpoint endpoint; - private Class inputFormatClazz; - private Class marshallerClazz; - private InputFormat inputFormat; - private Configuration configuration; - private DataVecMarshaller marshaller; - - - public DataVecConsumer(DataVecEndpoint endpoint, Processor processor) { - super(endpoint, processor); - this.endpoint = endpoint; - - try { - inputFormatClazz = (Class) Class.forName(endpoint.getInputFormat()); - inputFormat = inputFormatClazz.newInstance(); - marshallerClazz = (Class) Class.forName(endpoint.getInputMarshaller()); - marshaller = marshallerClazz.newInstance(); - configuration = new Configuration(); - for (String prop : endpoint.getConsumerProperties().keySet()) - configuration.set(prop, endpoint.getConsumerProperties().get(prop).toString()); - - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - //stub, still need to fill out more of the end point yet..endpoint will likely be initialized with a split - protected InputSplit inputFromExchange(Exchange exchange) { - return marshaller.getSplit(exchange); - } - - @Override - protected int poll() throws Exception { - Exchange exchange = endpoint.createExchange(); - InputSplit split = inputFromExchange(exchange); - RecordReader reader = inputFormat.createReader(split, configuration); - int numMessagesPolled = 0; - while (reader.hasNext()) { - // create a message body - while (reader.hasNext()) { - exchange.getIn().setBody(reader.next()); - - try { - // send message to next processor in the route - getProcessor().process(exchange); - numMessagesPolled++; // number of messages polled - } finally { - // log exception if an exception occurred and was not handled - if (exchange.getException() != null) { - getExceptionHandler().handleException("Error processing exchange", exchange, - exchange.getException()); - } - } - } - - - } - - return numMessagesPolled; - } -} diff --git a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecEndpoint.java b/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecEndpoint.java deleted file mode 100644 index 930f968a1..000000000 --- a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecEndpoint.java +++ /dev/null @@ -1,68 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.camel.component; - -import lombok.Data; -import org.apache.camel.Consumer; -import org.apache.camel.Processor; -import org.apache.camel.Producer; -import org.apache.camel.impl.DefaultEndpoint; -import org.apache.camel.spi.Metadata; -import org.apache.camel.spi.UriEndpoint; -import org.apache.camel.spi.UriParam; -import org.apache.camel.spi.UriPath; - -/** - * Represents a DataVec endpoint. - * @author Adam Gibson - */ -@UriEndpoint(scheme = "datavec", title = "datavec", syntax = "datavec:inputFormat/?outputFormat=?&inputMarshaller=?", - consumerClass = DataVecConsumer.class, label = "datavec") -@Data -public class DataVecEndpoint extends DefaultEndpoint { - @UriPath - @Metadata(required = "true") - private String inputFormat; - @UriParam(defaultValue = "") - private String outputFormat; - @UriParam - @Metadata(required = "true") - private String inputMarshaller; - @UriParam(defaultValue = "org.datavec.api.io.converters.SelfWritableConverter") - private String writableConverter; - - public DataVecEndpoint(String uri, DataVecComponent component) { - super(uri, component); - } - - public DataVecEndpoint(String endpointUri) { - super(endpointUri); - } - - public Producer createProducer() throws Exception { - return new DataVecProducer(this); - } - - public Consumer createConsumer(Processor processor) throws Exception { - return new DataVecConsumer(this, processor); - } - - public boolean isSingleton() { - return true; - } - -} diff --git a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecMarshaller.java b/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecMarshaller.java deleted file mode 100644 index 62eba29d7..000000000 --- a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecMarshaller.java +++ /dev/null @@ -1,36 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.camel.component; - -import org.apache.camel.Exchange; -import org.datavec.api.split.InputSplit; - -/** - * Marshals na exchange in to an input split - * @author Adam Gibson - */ -public interface DataVecMarshaller { - - - /** - * - * @param exchange - * @return - */ - InputSplit getSplit(Exchange exchange); - -} diff --git a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecProducer.java b/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecProducer.java deleted file mode 100644 index 45697ff1c..000000000 --- a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/DataVecProducer.java +++ /dev/null @@ -1,109 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.camel.component; - -import org.apache.camel.Exchange; -import org.apache.camel.impl.DefaultProducer; -import org.datavec.api.conf.Configuration; -import org.datavec.api.formats.input.InputFormat; -import org.datavec.api.io.WritableConverter; -import org.datavec.api.io.converters.SelfWritableConverter; -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.split.InputSplit; -import org.datavec.api.writable.Writable; - -import java.util.ArrayList; -import java.util.Collection; - - -/** - * The DataVec producer. - * Converts input records in to their final form - * based on the input split generated from - * the given exchange. - * - * @author Adam Gibson - */ -public class DataVecProducer extends DefaultProducer { - private Class inputFormatClazz; - private Class marshallerClazz; - private InputFormat inputFormat; - private Configuration configuration; - private WritableConverter writableConverter; - private DataVecMarshaller marshaller; - - - public DataVecProducer(DataVecEndpoint endpoint) { - super(endpoint); - if (endpoint.getInputFormat() != null) { - try { - inputFormatClazz = (Class) Class.forName(endpoint.getInputFormat()); - inputFormat = inputFormatClazz.newInstance(); - marshallerClazz = (Class) Class.forName(endpoint.getInputMarshaller()); - Class converterClazz = - (Class) Class.forName(endpoint.getWritableConverter()); - writableConverter = converterClazz.newInstance(); - marshaller = marshallerClazz.newInstance(); - configuration = new Configuration(); - for (String prop : endpoint.getConsumerProperties().keySet()) - configuration.set(prop, endpoint.getConsumerProperties().get(prop).toString()); - - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - } - - - //stub, still need to fill out more of the end point yet..endpoint will likely be initialized with a split - protected InputSplit inputFromExchange(Exchange exchange) { - return marshaller.getSplit(exchange); - } - - - @Override - public void process(Exchange exchange) throws Exception { - InputSplit split = inputFromExchange(exchange); - RecordReader reader = inputFormat.createReader(split, configuration); - Collection> newRecord = new ArrayList<>(); - if (!(writableConverter instanceof SelfWritableConverter)) { - newRecord = new ArrayList<>(); - while (reader.hasNext()) { - Collection newRecordAdd = new ArrayList<>(); - // create a message body - Collection next = reader.next(); - for (Writable writable : next) { - newRecordAdd.add(writableConverter.convert(writable)); - } - - - newRecord.add(newRecordAdd); - } - } else { - while (reader.hasNext()) { - // create a message body - Collection next = reader.next(); - newRecord.add(next); - } - } - - - exchange.getIn().setBody(newRecord); - exchange.getOut().setBody(newRecord); - } -} diff --git a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/csv/marshaller/ListStringInputMarshaller.java b/datavec/datavec-camel/src/main/java/org/datavec/camel/component/csv/marshaller/ListStringInputMarshaller.java deleted file mode 100644 index 9ff8913be..000000000 --- a/datavec/datavec-camel/src/main/java/org/datavec/camel/component/csv/marshaller/ListStringInputMarshaller.java +++ /dev/null @@ -1,42 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.camel.component.csv.marshaller; - -import org.apache.camel.Exchange; -import org.datavec.api.split.InputSplit; -import org.datavec.api.split.ListStringSplit; -import org.datavec.camel.component.DataVecMarshaller; - -import java.util.List; - -/** - * Marshals List> - * - * @author Adam Gibson - */ -public class ListStringInputMarshaller implements DataVecMarshaller { - /** - * @param exchange - * @return - */ - @Override - public InputSplit getSplit(Exchange exchange) { - List> data = (List>) exchange.getIn().getBody(); - InputSplit listSplit = new ListStringSplit(data); - return listSplit; - } -} diff --git a/datavec/datavec-camel/src/main/resources/META-INF/services/org/apache/camel/component/datavec b/datavec/datavec-camel/src/main/resources/META-INF/services/org/apache/camel/component/datavec deleted file mode 100644 index 2a324f975..000000000 --- a/datavec/datavec-camel/src/main/resources/META-INF/services/org/apache/camel/component/datavec +++ /dev/null @@ -1 +0,0 @@ -class=org.datavec.camel.component.DataVecComponent diff --git a/datavec/datavec-camel/src/test/java/org/datavec/camel/component/DataVecComponentTest.java b/datavec/datavec-camel/src/test/java/org/datavec/camel/component/DataVecComponentTest.java deleted file mode 100644 index de4aa35cf..000000000 --- a/datavec/datavec-camel/src/test/java/org/datavec/camel/component/DataVecComponentTest.java +++ /dev/null @@ -1,82 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.camel.component; - -import org.apache.camel.builder.RouteBuilder; -import org.apache.camel.component.mock.MockEndpoint; -import org.apache.camel.test.junit4.CamelTestSupport; -import org.apache.commons.io.FileUtils; -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.records.reader.impl.csv.CSVRecordReader; -import org.datavec.api.split.FileSplit; -import org.datavec.api.writable.Writable; -import org.junit.BeforeClass; -import org.junit.ClassRule; -import org.junit.Test; -import org.junit.rules.TemporaryFolder; -import org.nd4j.linalg.io.ClassPathResource; - -import java.io.File; -import java.util.ArrayList; -import java.util.Collection; - -public class DataVecComponentTest extends CamelTestSupport { - - @ClassRule - public static TemporaryFolder testDir = new TemporaryFolder(); - private static File dir; - private static File irisFile; - - - @BeforeClass - public static void before() throws Exception { - dir = testDir.newFolder(); - File iris = new ClassPathResource("iris.dat").getFile(); - irisFile = new File(dir, "iris.dat"); - FileUtils.copyFile(iris, irisFile ); - } - - - - @Test - public void testDataVec() throws Exception { - MockEndpoint mock = getMockEndpoint("mock:result"); - //1 - mock.expectedMessageCount(1); - - RecordReader reader = new CSVRecordReader(); - reader.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile())); - Collection> recordAssertion = new ArrayList<>(); - while (reader.hasNext()) - recordAssertion.add(reader.next()); - mock.expectedBodiesReceived(recordAssertion); - assertMockEndpointsSatisfied(); - } - - @Override - protected RouteBuilder createRouteBuilder() throws Exception { - - - return new RouteBuilder() { - public void configure() { - from("file:" + dir.getAbsolutePath() + "?fileName=iris.dat&noop=true").unmarshal().csv() - .to("datavec://org.datavec.api.formats.input.impl.ListStringInputFormat?inputMarshaller=org.datavec.camel.component.ListStringInputMarshaller&writableConverter=org.datavec.api.io.converters.SelfWritableConverter") - .to("mock:result"); - } - }; - } -} diff --git a/datavec/datavec-camel/src/test/java/org/datavec/camel/component/ListStringInputMarshaller.java b/datavec/datavec-camel/src/test/java/org/datavec/camel/component/ListStringInputMarshaller.java deleted file mode 100644 index 77b376d90..000000000 --- a/datavec/datavec-camel/src/test/java/org/datavec/camel/component/ListStringInputMarshaller.java +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.camel.component; - -import org.apache.camel.Exchange; -import org.datavec.api.split.InputSplit; -import org.datavec.api.split.ListStringSplit; - -import java.util.List; - -/** - * Marshals List> - * - * @author Adam Gibson - */ -public class ListStringInputMarshaller implements DataVecMarshaller { - /** - * @param exchange - * @return - */ - @Override - public InputSplit getSplit(Exchange exchange) { - List> data = (List>) exchange.getIn().getBody(); - InputSplit listSplit = new ListStringSplit(data); - return listSplit; - } -} diff --git a/datavec/datavec-camel/src/test/resources/log4j.properties b/datavec/datavec-camel/src/test/resources/log4j.properties deleted file mode 100644 index 09f815522..000000000 --- a/datavec/datavec-camel/src/test/resources/log4j.properties +++ /dev/null @@ -1,30 +0,0 @@ -################################################################################ -# Copyright (c) 2015-2018 Skymind, Inc. -# -# This program and the accompanying materials are made available under the -# terms of the Apache License, Version 2.0 which is available at -# https://www.apache.org/licenses/LICENSE-2.0. -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. -# -# SPDX-License-Identifier: Apache-2.0 -################################################################################ - -# -# The logging properties used -# -log4j.rootLogger=INFO, out - -# uncomment the following line to turn on Camel debugging -#log4j.logger.org.apache.camel=DEBUG - -# CONSOLE appender not used by default -log4j.appender.out=org.apache.log4j.ConsoleAppender -log4j.appender.out.layout=org.apache.log4j.PatternLayout -log4j.appender.out.layout.ConversionPattern=[%30.30t] %-30.30c{1} %-5p %m%n -#log4j.appender.out.layout.ConversionPattern=%d [%-15.15t] %-5p %-30.30c{1} - %m%n - diff --git a/datavec/datavec-data/datavec-data-image/pom.xml b/datavec/datavec-data/datavec-data-image/pom.xml index fc85482a4..0e4267659 100644 --- a/datavec/datavec-data/datavec-data-image/pom.xml +++ b/datavec/datavec-data/datavec-data-image/pom.xml @@ -37,11 +37,6 @@ ${logback.version} test - - org.nd4j - nd4j-buffer - ${nd4j.version} - com.github.jai-imageio jai-imageio-core diff --git a/datavec/datavec-geo/pom.xml b/datavec/datavec-data/datavec-geo/pom.xml similarity index 100% rename from datavec/datavec-geo/pom.xml rename to datavec/datavec-data/datavec-geo/pom.xml diff --git a/datavec/datavec-geo/src/main/java/org/datavec/api/transform/geo/LocationType.java b/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/geo/LocationType.java similarity index 100% rename from datavec/datavec-geo/src/main/java/org/datavec/api/transform/geo/LocationType.java rename to datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/geo/LocationType.java diff --git a/datavec/datavec-geo/src/main/java/org/datavec/api/transform/reduce/geo/CoordinatesReduction.java b/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/reduce/geo/CoordinatesReduction.java similarity index 100% rename from datavec/datavec-geo/src/main/java/org/datavec/api/transform/reduce/geo/CoordinatesReduction.java rename to datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/reduce/geo/CoordinatesReduction.java diff --git a/datavec/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/CoordinatesDistanceTransform.java b/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/CoordinatesDistanceTransform.java similarity index 100% rename from datavec/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/CoordinatesDistanceTransform.java rename to datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/CoordinatesDistanceTransform.java diff --git a/datavec/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/GeoIPFetcher.java b/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/GeoIPFetcher.java similarity index 100% rename from datavec/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/GeoIPFetcher.java rename to datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/GeoIPFetcher.java diff --git a/datavec/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToCoordinatesTransform.java b/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToCoordinatesTransform.java similarity index 100% rename from datavec/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToCoordinatesTransform.java rename to datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToCoordinatesTransform.java diff --git a/datavec/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToLocationTransform.java b/datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToLocationTransform.java similarity index 100% rename from datavec/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToLocationTransform.java rename to datavec/datavec-data/datavec-geo/src/main/java/org/datavec/api/transform/transform/geo/IPAddressToLocationTransform.java diff --git a/datavec/datavec-geo/src/test/java/org/datavec/api/transform/reduce/TestGeoReduction.java b/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/reduce/TestGeoReduction.java similarity index 100% rename from datavec/datavec-geo/src/test/java/org/datavec/api/transform/reduce/TestGeoReduction.java rename to datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/reduce/TestGeoReduction.java diff --git a/datavec/datavec-geo/src/test/java/org/datavec/api/transform/transform/TestGeoTransforms.java b/datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/transform/TestGeoTransforms.java similarity index 100% rename from datavec/datavec-geo/src/test/java/org/datavec/api/transform/transform/TestGeoTransforms.java rename to datavec/datavec-data/datavec-geo/src/test/java/org/datavec/api/transform/transform/TestGeoTransforms.java diff --git a/datavec/datavec-hadoop/pom.xml b/datavec/datavec-data/datavec-hadoop/pom.xml similarity index 100% rename from datavec/datavec-hadoop/pom.xml rename to datavec/datavec-data/datavec-hadoop/pom.xml diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/conf/ConfigurationUtil.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/conf/ConfigurationUtil.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/conf/ConfigurationUtil.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/conf/ConfigurationUtil.java diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/IndexToKey.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/IndexToKey.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/IndexToKey.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/IndexToKey.java diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileReader.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileReader.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileReader.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileReader.java diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileRecordReader.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileRecordReader.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileRecordReader.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileRecordReader.java diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileSequenceRecordReader.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileSequenceRecordReader.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileSequenceRecordReader.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/MapFileSequenceRecordReader.java diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/index/LongIndexToKey.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/index/LongIndexToKey.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/index/LongIndexToKey.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/index/LongIndexToKey.java diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/RecordWritable.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/RecordWritable.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/RecordWritable.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/RecordWritable.java diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/SequenceRecordWritable.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/SequenceRecordWritable.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/SequenceRecordWritable.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/reader/mapfile/record/SequenceRecordWritable.java diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/AbstractMapFileWriter.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/AbstractMapFileWriter.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/AbstractMapFileWriter.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/AbstractMapFileWriter.java diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileRecordWriter.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileRecordWriter.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileRecordWriter.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileRecordWriter.java diff --git a/datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileSequenceRecordWriter.java b/datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileSequenceRecordWriter.java similarity index 100% rename from datavec/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileSequenceRecordWriter.java rename to datavec/datavec-data/datavec-hadoop/src/main/java/org/datavec/hadoop/records/writer/mapfile/MapFileSequenceRecordWriter.java diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/conf/TestConfigurationUtil.java b/datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/conf/TestConfigurationUtil.java similarity index 100% rename from datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/conf/TestConfigurationUtil.java rename to datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/conf/TestConfigurationUtil.java diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java b/datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java similarity index 100% rename from datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java rename to datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReader.java diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java b/datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java similarity index 100% rename from datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java rename to datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultipleParts.java diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java b/datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java similarity index 100% rename from datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java rename to datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/records/reader/TestMapFileRecordReaderMultiplePartsSomeEmpty.java diff --git a/datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java b/datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java similarity index 100% rename from datavec/datavec-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java rename to datavec/datavec-data/datavec-hadoop/src/test/java/org/datavec/hadoop/records/writer/TestMapFileRecordWriter.java diff --git a/datavec/datavec-hadoop/src/test/resources/log4j.properties b/datavec/datavec-data/datavec-hadoop/src/test/resources/log4j.properties similarity index 100% rename from datavec/datavec-hadoop/src/test/resources/log4j.properties rename to datavec/datavec-data/datavec-hadoop/src/test/resources/log4j.properties diff --git a/datavec/datavec-hadoop/src/test/resources/logback.xml b/datavec/datavec-data/datavec-hadoop/src/test/resources/logback.xml similarity index 100% rename from datavec/datavec-hadoop/src/test/resources/logback.xml rename to datavec/datavec-data/datavec-hadoop/src/test/resources/logback.xml diff --git a/datavec/datavec-perf/pom.xml b/datavec/datavec-perf/pom.xml deleted file mode 100644 index a51b9aba1..000000000 --- a/datavec/datavec-perf/pom.xml +++ /dev/null @@ -1,65 +0,0 @@ - - - - - - - datavec-parent - org.datavec - 1.0.0-SNAPSHOT - - 4.0.0 - - datavec-perf - - datavec-perf - - - UTF-8 - 1.7 - 1.7 - - - - - org.slf4j - slf4j-api - ${slf4j.version} - - - org.datavec - datavec-data-image - ${project.version} - test - - - org.datavec - datavec-api - ${project.version} - - - - - - test-nd4j-native - - - test-nd4j-cuda-10.2 - - - diff --git a/datavec/datavec-perf/src/main/java/org/datavec/perf/timing/IOTiming.java b/datavec/datavec-perf/src/main/java/org/datavec/perf/timing/IOTiming.java deleted file mode 100644 index aa82194e1..000000000 --- a/datavec/datavec-perf/src/main/java/org/datavec/perf/timing/IOTiming.java +++ /dev/null @@ -1,112 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.perf.timing; - -import lombok.val; -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.split.InputStreamInputSplit; -import org.datavec.api.writable.Writable; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.performance.PerformanceTracker; -import org.nd4j.linalg.memory.MemcpyDirection; - -import java.io.BufferedInputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.InputStream; -import java.util.List; -import java.util.Map; - - - -/** - * Timing components of a data vec pipeline - * consisting of: - * {@link RecordReader}, {@link InputStreamInputSplit} - * (note that this uses input stream input split, - * the record reader must support {@link InputStreamInputSplit} for this to work) - * - * @author Adam Gibson - */ -public class IOTiming { - - - /** - * Returns statistics for components of a datavec pipeline - * averaged over the specified number of times - * @param nTimes the number of times to run the pipeline for averaging - * @param recordReader the record reader - * @param file the file to read - * @param function the function - * @return the averaged {@link TimingStatistics} for input/output on a record - * reader and ndarray creation (based on the given function - * @throws Exception - */ - public static TimingStatistics averageFileRead(long nTimes, RecordReader recordReader, File file, INDArrayCreationFunction function) throws Exception { - TimingStatistics timingStatistics = null; - for(int i = 0; i < nTimes; i++) { - TimingStatistics timingStatistics1 = timeNDArrayCreation(recordReader,new BufferedInputStream(new FileInputStream(file)),function); - if(timingStatistics == null) - timingStatistics = timingStatistics1; - else { - timingStatistics = timingStatistics.add(timingStatistics1); - } - - } - - return timingStatistics.average(nTimes); - } - - /** - * - * @param reader - * @param inputStream - * @param function - * @return - * @throws Exception - */ - public static TimingStatistics timeNDArrayCreation(RecordReader reader, - InputStream inputStream, - INDArrayCreationFunction function) throws Exception { - - - reader.initialize(new InputStreamInputSplit(inputStream)); - long longNanos = System.nanoTime(); - List next = reader.next(); - long endNanos = System.nanoTime(); - long etlDiff = endNanos - longNanos; - long startArrCreation = System.nanoTime(); - INDArray arr = function.createFromRecord(next); - long endArrCreation = System.nanoTime(); - long endCreationDiff = endArrCreation - startArrCreation; - Map> currentBandwidth = PerformanceTracker.getInstance().getCurrentBandwidth(); - val bw = currentBandwidth.get(0).get(MemcpyDirection.HOST_TO_DEVICE); - val deviceToHost = currentBandwidth.get(0).get(MemcpyDirection.HOST_TO_DEVICE); - - return TimingStatistics.builder() - .diskReadingTimeNanos(etlDiff) - .bandwidthNanosHostToDevice(bw) - .bandwidthDeviceToHost(deviceToHost) - .ndarrayCreationTimeNanos(endCreationDiff) - .build(); - } - - public interface INDArrayCreationFunction { - INDArray createFromRecord(List record); - } - -} diff --git a/datavec/datavec-perf/src/main/java/org/datavec/perf/timing/TimingStatistics.java b/datavec/datavec-perf/src/main/java/org/datavec/perf/timing/TimingStatistics.java deleted file mode 100644 index 615d9659a..000000000 --- a/datavec/datavec-perf/src/main/java/org/datavec/perf/timing/TimingStatistics.java +++ /dev/null @@ -1,74 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.perf.timing; - -import lombok.Builder; -import lombok.Data; - - -/** - * The timing statistics for a data pipeline including: - * ndarray creation time in nanoseconds - * disk reading time in nanoseconds - * bandwidth used in host to device in nano seconds - * bandwidth device to host in nanoseconds - * - * @author Adam Gibson - */ -@Builder -@Data -public class TimingStatistics { - - private long ndarrayCreationTimeNanos; - private long diskReadingTimeNanos; - private long bandwidthNanosHostToDevice; - private long bandwidthDeviceToHost; - - - /** - * Accumulate the given statistics - * @param timingStatistics the statistics to add - * @return the added statistics - */ - public TimingStatistics add(TimingStatistics timingStatistics) { - return TimingStatistics.builder() - .ndarrayCreationTimeNanos(ndarrayCreationTimeNanos + timingStatistics.ndarrayCreationTimeNanos) - .bandwidthNanosHostToDevice(bandwidthNanosHostToDevice + timingStatistics.bandwidthNanosHostToDevice) - .diskReadingTimeNanos(diskReadingTimeNanos + timingStatistics.diskReadingTimeNanos) - .bandwidthDeviceToHost(bandwidthDeviceToHost + timingStatistics.bandwidthDeviceToHost) - .build(); - } - - - /** - * Average the results relative to the number of n. - * This method is meant to be used alongside - * {@link #add(TimingStatistics)} - * accumulated a number of times - * @param n n the number of elements - * @return the averaged results - */ - public TimingStatistics average(long n) { - return TimingStatistics.builder() - .ndarrayCreationTimeNanos(ndarrayCreationTimeNanos / n) - .bandwidthDeviceToHost(bandwidthDeviceToHost / n) - .diskReadingTimeNanos(diskReadingTimeNanos / n) - .bandwidthNanosHostToDevice(bandwidthNanosHostToDevice / n) - .build(); - } - -} diff --git a/datavec/datavec-perf/src/test/java/org/datavec/datavec/timing/IOTimingTest.java b/datavec/datavec-perf/src/test/java/org/datavec/datavec/timing/IOTimingTest.java deleted file mode 100644 index 60c206232..000000000 --- a/datavec/datavec-perf/src/test/java/org/datavec/datavec/timing/IOTimingTest.java +++ /dev/null @@ -1,60 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.datavec.datavec.timing; - -import org.datavec.api.records.reader.RecordReader; -import org.datavec.api.writable.NDArrayWritable; -import org.datavec.api.writable.Writable; -import org.datavec.image.loader.NativeImageLoader; -import org.datavec.image.recordreader.ImageRecordReader; -import org.datavec.perf.timing.IOTiming; -import org.datavec.perf.timing.TimingStatistics; -import org.junit.Test; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.io.ClassPathResource; - -import java.util.List; - -public class IOTimingTest { - - @Test - public void testTiming() throws Exception { - final RecordReader image = new ImageRecordReader(28,28); - final NativeImageLoader nativeImageLoader = new NativeImageLoader(28,28); - - TimingStatistics timingStatistics = IOTiming.timeNDArrayCreation(image, new ClassPathResource("datavec-perf/largestblobtest.jpg").getInputStream(), new IOTiming.INDArrayCreationFunction() { - @Override - public INDArray createFromRecord(List record) { - NDArrayWritable imageWritable = (NDArrayWritable) record.get(0); - return imageWritable.get(); - } - }); - - System.out.println(timingStatistics); - - TimingStatistics timingStatistics1 = IOTiming.averageFileRead(1000,image,new ClassPathResource("datavec-perf/largestblobtest.jpg").getFile(), new IOTiming.INDArrayCreationFunction() { - @Override - public INDArray createFromRecord(List record) { - NDArrayWritable imageWritable = (NDArrayWritable) record.get(0); - return imageWritable.get(); - } - }); - - System.out.println(timingStatistics1); - } - -} diff --git a/datavec/pom.xml b/datavec/pom.xml index ff4bfa51b..c706469a9 100644 --- a/datavec/pom.xml +++ b/datavec/pom.xml @@ -59,16 +59,12 @@ datavec-api datavec-data - datavec-geo - datavec-hadoop datavec-spark - datavec-camel datavec-local datavec-spark-inference-parent datavec-jdbc datavec-excel datavec-arrow - datavec-perf datavec-python diff --git a/deeplearning4j/deeplearning4j-core/pom.xml b/deeplearning4j/deeplearning4j-core/pom.xml index 728ce9598..90c88d4c3 100644 --- a/deeplearning4j/deeplearning4j-core/pom.xml +++ b/deeplearning4j/deeplearning4j-core/pom.xml @@ -153,6 +153,17 @@ ${jaxb.version} provided + + + com.github.oshi + oshi-json + ${oshi.version} + + + com.github.oshi + oshi-core + ${oshi.version} + diff --git a/deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/DeviceMetric.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/DeviceMetric.java similarity index 100% rename from deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/DeviceMetric.java rename to deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/DeviceMetric.java diff --git a/deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/DiskInfo.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/DiskInfo.java similarity index 100% rename from deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/DiskInfo.java rename to deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/DiskInfo.java diff --git a/deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/HardwareMetric.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/HardwareMetric.java similarity index 100% rename from deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/HardwareMetric.java rename to deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/HardwareMetric.java diff --git a/deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/SystemInfoFilePrintListener.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/SystemInfoFilePrintListener.java similarity index 100% rename from deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/SystemInfoFilePrintListener.java rename to deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/SystemInfoFilePrintListener.java diff --git a/deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/SystemInfoPrintListener.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/SystemInfoPrintListener.java similarity index 100% rename from deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/SystemInfoPrintListener.java rename to deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/SystemInfoPrintListener.java diff --git a/deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/SystemPolling.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/SystemPolling.java similarity index 100% rename from deeplearning4j/dl4j-perf/src/main/java/org/deeplearning4j/perf/listener/SystemPolling.java rename to deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/perf/listener/SystemPolling.java diff --git a/deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/MovingWindowMatrix.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/MovingWindowMatrix.java old mode 100755 new mode 100644 similarity index 100% rename from deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/MovingWindowMatrix.java rename to deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/MovingWindowMatrix.java diff --git a/deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/ThreadUtils.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/ThreadUtils.java similarity index 100% rename from deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/ThreadUtils.java rename to deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/ThreadUtils.java diff --git a/deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/UIDProvider.java b/deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/UIDProvider.java similarity index 100% rename from deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/UIDProvider.java rename to deeplearning4j/deeplearning4j-core/src/main/java/org/deeplearning4j/util/UIDProvider.java diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/adapters/ArgmaxAdapterTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/adapters/ArgmaxAdapterTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/adapters/ArgmaxAdapterTest.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/adapters/ArgmaxAdapterTest.java diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/adapters/Regression2dAdapterTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/adapters/Regression2dAdapterTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/adapters/Regression2dAdapterTest.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/adapters/Regression2dAdapterTest.java diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java similarity index 100% rename from deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/recurrent/MaskZeroLayerTest.java diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java similarity index 100% rename from deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/layers/samediff/SameDiffCustomLayerTests.java diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulatorTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/EncodedGradientsAccumulatorTest.java similarity index 91% rename from deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulatorTest.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/EncodedGradientsAccumulatorTest.java index d75653604..bae025caf 100644 --- a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulatorTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/EncodedGradientsAccumulatorTest.java @@ -14,11 +14,13 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.optimize.solvers.accumulation; +package org.deeplearning4j.optimize.solver.accumulation; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.optimize.solvers.accumulation.EncodedGradientsAccumulator; +import org.deeplearning4j.optimize.solvers.accumulation.EncodingHandler; import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.FixedThresholdAlgorithm; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; @@ -68,8 +70,8 @@ public class EncodedGradientsAccumulatorTest extends BaseDL4JTest { accumulator.receiveUpdate(encoded); // just purge updates, like they were consumed - for (int i = 0; i < accumulator.messages.size(); i++) { - accumulator.messages.get(i).clear(); + for (int i = 0; i < accumulator.getMessages().size(); i++) { + accumulator.getMessages().get(i).clear(); } } } diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTailTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/IndexedTailTest.java similarity index 97% rename from deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTailTest.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/IndexedTailTest.java index a7884bb56..512b77c35 100644 --- a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTailTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/IndexedTailTest.java @@ -14,12 +14,13 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.optimize.solvers.accumulation; +package org.deeplearning4j.optimize.solver.accumulation; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.lang3.RandomUtils; import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.optimize.solvers.accumulation.IndexedTail; import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.factory.Nd4j; @@ -97,9 +98,9 @@ public class IndexedTailTest extends BaseDL4JTest { assertEquals(-1, tail.maxAppliedIndexEverywhere()); // 2 consumers consumed 2 elements, and 1 consumer consumed 3 elements - tail.positions.get(11L).set(2); - tail.positions.get(22L).set(2); - tail.positions.get(33L).set(3); + tail.getPositions().get(11L).set(2); + tail.getPositions().get(22L).set(2); + tail.getPositions().get(33L).set(3); // all elements including this index are safe to remove, because they were consumed everywhere assertEquals(2, tail.maxAppliedIndexEverywhere()); @@ -197,10 +198,10 @@ public class IndexedTailTest extends BaseDL4JTest { Nd4j.getExecutioner().commit(); } - assertTrue(tail.collapsedMode.get()); + assertTrue(tail.getCollapsedMode().get()); assertEquals(1, tail.updatesSize()); - val array = tail.updates.get(32L); + val array = tail.getUpdates().get(32L); assertNotNull(array); assertEquals(sum, (int) array.getDouble(0)); } diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/SmartFancyBlockingQueueTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java similarity index 98% rename from deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/SmartFancyBlockingQueueTest.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java index f5afb0e48..6496fd9e3 100644 --- a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/SmartFancyBlockingQueueTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/SmartFancyBlockingQueueTest.java @@ -14,12 +14,13 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.optimize.solvers.accumulation; +package org.deeplearning4j.optimize.solver.accumulation; import lombok.extern.slf4j.Slf4j; import lombok.val; import org.apache.commons.lang3.RandomUtils; import org.deeplearning4j.BaseDL4JTest; +import org.deeplearning4j.optimize.solvers.accumulation.SmartFancyBlockingQueue; import org.deeplearning4j.util.ThreadUtils; import org.junit.Ignore; import org.junit.Test; diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/ThresholdAlgorithmTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/ThresholdAlgorithmTests.java similarity index 98% rename from deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/ThresholdAlgorithmTests.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/ThresholdAlgorithmTests.java index 52744bf12..578674ed6 100644 --- a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/solvers/accumulation/ThresholdAlgorithmTests.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/accumulation/ThresholdAlgorithmTests.java @@ -14,7 +14,7 @@ * SPDX-License-Identifier: Apache-2.0 ******************************************************************************/ -package org.deeplearning4j.optimize.solvers.accumulation; +package org.deeplearning4j.optimize.solver.accumulation; import org.deeplearning4j.BaseDL4JTest; import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm; diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/listeners/ScoreStatTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/ScoreStatTest.java similarity index 96% rename from deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/listeners/ScoreStatTest.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/ScoreStatTest.java index 5a3588e25..fdc86c6cc 100644 --- a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/optimize/listeners/ScoreStatTest.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimizer/listener/ScoreStatTest.java @@ -1,5 +1,6 @@ -package org.deeplearning4j.optimize.listeners; +package org.deeplearning4j.optimizer.listener; +import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener; import org.junit.Ignore; import org.junit.Test; diff --git a/deeplearning4j/dl4j-perf/src/test/java/org/deeplearning4j/perf/listener/SystemPollingTest.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/SystemPollingTest.java similarity index 100% rename from deeplearning4j/dl4j-perf/src/test/java/org/deeplearning4j/perf/listener/SystemPollingTest.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/SystemPollingTest.java diff --git a/deeplearning4j/dl4j-perf/src/test/java/org/deeplearning4j/perf/listener/TestHardWareMetric.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/TestHardWareMetric.java similarity index 100% rename from deeplearning4j/dl4j-perf/src/test/java/org/deeplearning4j/perf/listener/TestHardWareMetric.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/TestHardWareMetric.java diff --git a/deeplearning4j/dl4j-perf/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java similarity index 100% rename from deeplearning4j/dl4j-perf/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java rename to deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/perf/listener/TestSystemInfoPrintListener.java diff --git a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CuDNNGradientChecks.java b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CuDNNGradientChecks.java index f1ea9f1e1..9e43f042b 100644 --- a/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CuDNNGradientChecks.java +++ b/deeplearning4j/deeplearning4j-cuda/src/test/java/org/deeplearning4j/gradientcheck/CuDNNGradientChecks.java @@ -305,7 +305,7 @@ public class CuDNNGradientChecks extends BaseDL4JTest { //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" Set excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev")); boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, excludeParams); + DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, false, -1, excludeParams, null); assertTrue(gradOK); } @@ -417,7 +417,7 @@ public class CuDNNGradientChecks extends BaseDL4JTest { } boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 32); + DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 32, null, null); assertTrue(gradOK); } @@ -655,9 +655,12 @@ public class CuDNNGradientChecks extends BaseDL4JTest { }; log.info("*** Starting test: " + msg + " ***"); - boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, l, null, null, - false, -1, null, c); + boolean gradOK = GradientCheckUtil.checkGradients( + new GradientCheckUtil.MLNConfig().net(mln).epsilon(DEFAULT_EPS) + .maxRelError(DEFAULT_MAX_REL_ERROR).minAbsoluteError(DEFAULT_MIN_ABS_ERROR) + .print(PRINT_RESULTS ? GradientCheckUtil.PrintMode.ZEROS : GradientCheckUtil.PrintMode.FAILURES_ONLY) + .exitOnFirstError(RETURN_ON_FIRST_FAILURE) + .input(f).labels(l).callEachIter(c)); assertTrue(msg, gradOK); TestUtils.testModelSerialization(mln); @@ -691,7 +694,7 @@ public class CuDNNGradientChecks extends BaseDL4JTest { //However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter" Set excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev")); boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR, - DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, excludeParams); + DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, null, null, false, -1, excludeParams, null); assertTrue(gradOK); diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulator.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulator.java index 760d2076c..7448581c5 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulator.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/EncodedGradientsAccumulator.java @@ -63,6 +63,7 @@ public class EncodedGradientsAccumulator implements GradientsAccumulator, Regist protected int parties; @Getter protected MessageHandler handler; + @Getter protected List> messages = new ArrayList<>(); protected List workspaces = new ArrayList<>(); protected List locks = new ArrayList<>(); @@ -106,7 +107,7 @@ public class EncodedGradientsAccumulator implements GradientsAccumulator, Regist this(parties, new EncodingHandler(thresholdAlgorithm, residualPostProcessor, 1.0, encodingDebugMode), DEFAULT_INITIAL_MEMORY, 10, 1.0, encodingDebugMode); } - protected EncodedGradientsAccumulator(int parties, @NonNull MessageHandler handler, long initialMemory, + public EncodedGradientsAccumulator(int parties, @NonNull MessageHandler handler, long initialMemory, int queueSize, Double boundary, boolean encodingDebugMode) { this.parties = parties; this.handler = handler; diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTail.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTail.java index b53986102..c04440cd9 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTail.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/optimize/solvers/accumulation/IndexedTail.java @@ -16,6 +16,7 @@ package org.deeplearning4j.optimize.solvers.accumulation; +import lombok.Getter; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; import lombok.val; @@ -44,9 +45,11 @@ import java.util.concurrent.locks.ReentrantReadWriteLock; @Slf4j public class IndexedTail { // here we store positions of individual consumers + @Getter protected ConcurrentHashMap positions = new ConcurrentHashMap<>(); // here we store individual updates + @Getter protected Map updates = new ConcurrentHashMap<>(); // simple counter for new updates @@ -67,6 +70,7 @@ public class IndexedTail { protected final boolean allowCollapse; protected final long[] shape; protected final int collapseThreshold = 32; + @Getter protected AtomicBoolean collapsedMode = new AtomicBoolean(false); protected AtomicLong collapsedIndex = new AtomicLong(-1); @@ -148,7 +152,7 @@ public class IndexedTail { } } - protected long firstNotAppliedIndexEverywhere() { + public long firstNotAppliedIndexEverywhere() { long maxIdx = -1; // if there's no updates posted yet - just return negative value @@ -163,7 +167,7 @@ public class IndexedTail { return maxIdx + 1; } - protected long maxAppliedIndexEverywhere() { + public long maxAppliedIndexEverywhere() { long maxIdx = Long.MAX_VALUE; for (val v:positions.values()) { if (v.get() < maxIdx) @@ -212,7 +216,7 @@ public class IndexedTail { return getDelta(Thread.currentThread().getId()); } - protected long getDelta(long threadId) { + public long getDelta(long threadId) { return getGlobalPosition() - getLocalPosition(threadId); } @@ -293,7 +297,7 @@ public class IndexedTail { /** * This method does maintenance of updates within */ - protected synchronized void maintenance() { + public synchronized void maintenance() { // first of all we're checking, if all consumers were already registered. if not - just no-op. if (positions.size() < expectedConsumers) { log.trace("Skipping maintanance due to not all expected consumers shown up: [{}] vs [{}]", positions.size(), expectedConsumers); @@ -326,7 +330,7 @@ public class IndexedTail { * This method returns actual number of updates stored within tail * @return */ - protected int updatesSize() { + public int updatesSize() { return updates.size(); } @@ -348,11 +352,11 @@ public class IndexedTail { return result; } - protected boolean isDead() { + public boolean isDead() { return dead.get(); } - protected void notifyDead() { + public void notifyDead() { dead.set(true); } diff --git a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/TestUtils.java b/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/TestUtils.java deleted file mode 100644 index c60822ef7..000000000 --- a/deeplearning4j/deeplearning4j-nn/src/test/java/org/deeplearning4j/TestUtils.java +++ /dev/null @@ -1,158 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j; - -import org.apache.commons.compress.utils.IOUtils; -import org.deeplearning4j.nn.conf.ComputationGraphConfiguration; -import org.deeplearning4j.nn.conf.MultiLayerConfiguration; -import org.deeplearning4j.nn.graph.ComputationGraph; -import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; -import org.deeplearning4j.util.ModelSerializer; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; -import org.nd4j.linalg.factory.Nd4j; - -import java.io.*; -import java.util.Random; - -import static org.junit.Assert.assertEquals; - -public class TestUtils { - - public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){ - - MultiLayerNetwork restored; - try { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ModelSerializer.writeModel(net, baos, true); - byte[] bytes = baos.toByteArray(); - - ByteArrayInputStream bais = new ByteArrayInputStream(bytes); - restored = ModelSerializer.restoreMultiLayerNetwork(bais, true); - - assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations()); - assertEquals(net.params(), restored.params()); - } catch (IOException e){ - //Should never happen - throw new RuntimeException(e); - } - - //Also check the MultiLayerConfiguration is serializable (required by Spark etc) - MultiLayerConfiguration conf = net.getLayerWiseConfigurations(); - serializeDeserializeJava(conf); - - return restored; - } - - public static ComputationGraph testModelSerialization(ComputationGraph net){ - - ComputationGraph restored; - try { - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - ModelSerializer.writeModel(net, baos, true); - byte[] bytes = baos.toByteArray(); - - ByteArrayInputStream bais = new ByteArrayInputStream(bytes); - restored = ModelSerializer.restoreComputationGraph(bais, true); - - assertEquals(net.getConfiguration(), restored.getConfiguration()); - assertEquals(net.params(), restored.params()); - } catch (IOException e){ - //Should never happen - throw new RuntimeException(e); - } - - //Also check the ComputationGraphConfiguration is serializable (required by Spark etc) - ComputationGraphConfiguration conf = net.getConfiguration(); - serializeDeserializeJava(conf); - - return restored; - } - - private static T serializeDeserializeJava(T object){ - byte[] bytes; - try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){ - oos.writeObject(object); - oos.close(); - bytes = baos.toByteArray(); - } catch (IOException e){ - //Should never happen - throw new RuntimeException(e); - } - - T out; - try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))){ - out = (T)ois.readObject(); - } catch (IOException | ClassNotFoundException e){ - throw new RuntimeException(e); - } - - assertEquals(object, out); - return out; - } - - public static INDArray randomOneHot(long examples, long nOut){ - return randomOneHot(examples, nOut, new Random(12345)); - } - - public static INDArray randomOneHot(long examples, long nOut, long rngSeed){ - return randomOneHot(examples, nOut, new Random(rngSeed)); - } - - public static INDArray randomOneHot(long examples, long nOut, Random rng){ - INDArray arr = Nd4j.create(examples, nOut); - for( int i=0; i - - - 4.0.0 - - org.deeplearning4j - deeplearning4j-scaleout - 1.0.0-SNAPSHOT - - - deeplearning4j-aws_2.11 - DeepLearning4j-AWS - 1.0.0-SNAPSHOT - - - - - maven-compiler-plugin - - 1.8 - 1.8 - - - - - - - 2.11.12 - 2.11 - - - - - com.amazonaws - aws-java-sdk - 1.11.24 - - - args4j - args4j - 2.32 - - - org.slf4j - slf4j-api - - - org.nd4j - nd4j-api - ${nd4j.version} - - - org.deeplearning4j - deeplearning4j-util - ${project.version} - - - - com.jcraft - jsch - ${jsch.version} - - - - org.threadly - threadly - ${threadly.version} - - - - org.apache.commons - commons-lang3 - ${commons-lang3.version} - - - - - - test-nd4j-native - - - test-nd4j-cuda-10.2 - - - diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/dataset/DataSetLoader.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/dataset/DataSetLoader.java deleted file mode 100755 index a61a5da69..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/dataset/DataSetLoader.java +++ /dev/null @@ -1,34 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.dataset; - -import org.deeplearning4j.aws.s3.reader.S3Downloader; - -import java.io.InputStream; - -public class DataSetLoader { - - private String bucket; - - - - public void onData(InputStream is) { - S3Downloader downloader = new S3Downloader(); - - } - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/Ec2BoxCreator.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/Ec2BoxCreator.java deleted file mode 100755 index 2d0a58685..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/Ec2BoxCreator.java +++ /dev/null @@ -1,222 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.ec2; - -import com.amazonaws.regions.Regions; -import com.amazonaws.services.ec2.AmazonEC2; -import com.amazonaws.services.ec2.model.*; -import org.deeplearning4j.aws.s3.BaseS3; -import org.deeplearning4j.util.ThreadUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -/** - * Creates Ec2Boxes - * @author Adam Gibson - * - */ -public class Ec2BoxCreator extends BaseS3 { - - - private String amiId; - private int numBoxes; - private String size; - private List boxesCreated; - private String securityGroupId; - private String keyPair; - private Regions regions = Regions.DEFAULT_REGION; - private static final Logger log = LoggerFactory.getLogger(Ec2BoxCreator.class); - - //centos - public final static String DEFAULT_AMI = "ami-8997afe0"; - - /** - * - * @param numBoxes number of boxes - * @param size the size of the instances - */ - public Ec2BoxCreator(int numBoxes, String size, String securityGroupId, String keyPair) { - this(DEFAULT_AMI, numBoxes, size, securityGroupId, keyPair); - } - - - /** - * - * @param amiId amazon image id - * @param numBoxes number of boxes - * @param size the size of the instances - * @param securityGroupId - */ - public Ec2BoxCreator(String amiId, int numBoxes, String size, String securityGroupId, String keyPair) { - super(); - this.amiId = amiId; - this.numBoxes = numBoxes; - this.size = size; - this.keyPair = keyPair; - this.securityGroupId = securityGroupId; - } - - public void createSpot() { - // Initializes a Spot Instance Request - RequestSpotInstancesRequest requestRequest = new RequestSpotInstancesRequest(); - - // Request 1 x t1.micro instance with a bid price of $0.03. - requestRequest.setSpotPrice("0.03"); - requestRequest.setInstanceCount(Integer.valueOf(1)); - - // Setup the specifications of the launch. This includes the - // instance type (e.g. t1.micro) and the latest Amazon Linux - // AMI id available. Note, you should always use the latest - // Amazon Linux AMI id or another of your choosing. - LaunchSpecification launchSpecification = new LaunchSpecification(); - launchSpecification.setImageId("ami-8c1fece5"); - launchSpecification.setInstanceType("t1.micro"); - - // Add the security group to the request. - List securityGroups = new ArrayList<>(); - securityGroups.add("GettingStartedGroup"); - launchSpecification.setSecurityGroups(securityGroups); - - // Add the launch specifications to the request. - requestRequest.setLaunchSpecification(launchSpecification); - - // Call the RequestSpotInstance API. - RequestSpotInstancesResult requestResult = getEc2().requestSpotInstances(requestRequest); - - - List requestResponses = requestResult.getSpotInstanceRequests(); - - // Setup an arraylist to collect all of the request ids we want to - // watch hit the running state. - List spotInstanceRequestIds = new ArrayList<>(); - - // Add all of the request ids to the hashset, so we can determine when they hit the - // active state. - for (SpotInstanceRequest requestResponse : requestResponses) { - System.out.println("Created Spot Request: " + requestResponse.getSpotInstanceRequestId()); - spotInstanceRequestIds.add(requestResponse.getSpotInstanceRequestId()); - } - - } - - public void setRegion(Regions regions) { - this.regions = regions; - } - - - /** - * Create the instances - */ - public void create() { - RunInstancesRequest runInstancesRequest = - new RunInstancesRequest().withImageId(amiId).withInstanceType(size).withKeyName(keyPair) - .withMinCount(1).withSecurityGroupIds(securityGroupId).withMaxCount(numBoxes); - AmazonEC2 ec2 = getEc2(); - ec2.setRegion(com.amazonaws.regions.Region.getRegion(regions)); - List boxes = ec2.runInstances(runInstancesRequest).getReservation().getInstances(); - if (boxesCreated == null) { - boxesCreated = new ArrayList<>(); - for (Instance i : boxes) - boxesCreated.add(i.getInstanceId()); - - - - log.info("Boxes created " + boxesCreated); - } else { - blowupBoxes(); - boxesCreated.clear(); - for (Instance i : boxes) - boxesCreated.add(i.getInstanceId()); - - } - } - - - - public List blowupBoxes() { - TerminateInstancesRequest request = new TerminateInstancesRequest().withInstanceIds(boxesCreated); - - if (boxesCreated != null) { - TerminateInstancesResult result = getEc2().terminateInstances(request); - List change = result.getTerminatingInstances(); - log.info("Boxes destroyed " + boxesCreated); - return change; - } - - return Collections.emptyList(); - } - - - public void blockTillAllRunning() { - while (!allRunning()) { - ThreadUtils.uncheckedSleep(1000); - log.info("Not all created..."); - } - } - - public boolean allRunning() { - if (boxesCreated == null) - return false; - else { - DescribeInstancesResult result = getEc2().describeInstances(); - List reservations = result.getReservations(); - for (Reservation r : reservations) { - List instances = r.getInstances(); - for (Instance instance : instances) { - InstanceState state = instance.getState(); - if (state.getCode() == 48) - continue; - if (state.getCode() != 16) - return false; - } - } - - return true; - } - - - } - - public List getHosts() { - DescribeInstancesResult result = getEc2().describeInstances(); - List hosts = new ArrayList<>(); - List reservations = result.getReservations(); - for (Reservation r : reservations) { - List instances = r.getInstances(); - for (Instance instance : instances) { - InstanceState state = instance.getState(); - if (state.getCode() == 48) - continue; - hosts.add(instance.getPublicDnsName()); - - } - } - - return hosts; - } - - public List getBoxesCreated() { - return boxesCreated; - } - - - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/provision/ClusterSetup.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/provision/ClusterSetup.java deleted file mode 100755 index 94ac12c4b..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/provision/ClusterSetup.java +++ /dev/null @@ -1,122 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.ec2.provision; - -import com.amazonaws.regions.Regions; -import org.deeplearning4j.aws.ec2.Ec2BoxCreator; -import org.kohsuke.args4j.CmdLineException; -import org.kohsuke.args4j.CmdLineParser; -import org.kohsuke.args4j.Option; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.threadly.concurrent.PriorityScheduler; - -import java.util.List; - -/** - * Sets up a DL4J cluster - * @author Adam Gibson - * - */ -public class ClusterSetup { - - @Option(name = "-w", usage = "Number of workers") - private int numWorkers = 1; - @Option(name = "-ami", usage = "Amazon machine image: default, amazon linux (only works with RHEL right now") - private String ami = "ami-fb8e9292"; - @Option(name = "-s", usage = "size of instance: default m1.medium") - private String size = "m3.xlarge"; - @Option(name = "-sg", usage = "security group, this needs to be applyTransformToDestination") - private String securityGroupName; - @Option(name = "-kp", usage = "key pair name, also needs to be applyTransformToDestination.") - private String keyPairName; - @Option(name = "-kpath", - usage = "path to private key - needs to be applyTransformToDestination, this is used to login to amazon.") - private String pathToPrivateKey; - @Option(name = "-wscript", usage = "path to worker script to run, this will allow customization of dependencies") - private String workerSetupScriptPath; - @Option(name = "-mscript", usage = "path to master script to run this will allow customization of the dependencies") - private String masterSetupScriptPath; - @Option(name = "-region", usage = "specify a region") - private String region = Regions.US_EAST_1.getName(); - - private PriorityScheduler as; - - private static final Logger log = LoggerFactory.getLogger(ClusterSetup.class); - - - public ClusterSetup(String[] args) { - CmdLineParser parser = new CmdLineParser(this); - try { - parser.parseArgument(args); - } catch (CmdLineException e) { - parser.printUsage(System.err); - log.error("Unable to parse args", e); - } - - - } - - public void exec() { - //master + workers - Ec2BoxCreator boxCreator = new Ec2BoxCreator(ami, numWorkers, size, securityGroupName, keyPairName); - boxCreator.setRegion(Regions.fromName(region)); - boxCreator.create(); - boxCreator.blockTillAllRunning(); - List hosts = boxCreator.getHosts(); - //provisionMaster(hosts.get(0)); - provisionWorkers(hosts); - - - } - - - - private void provisionWorkers(List workers) { - as = new PriorityScheduler(Runtime.getRuntime().availableProcessors()); - for (final String workerHost : workers) { - try { - as.execute(new Runnable() { - @Override - public void run() { - HostProvisioner uploader = new HostProvisioner(workerHost, "ec2-user"); - try { - uploader.addKeyFile(pathToPrivateKey); - //uploader.runRemoteCommand("sudo hostname " + workerHost); - uploader.uploadAndRun(workerSetupScriptPath, ""); - } catch (Exception e) { - e.printStackTrace(); - } - - } - }); - - } catch (Exception e) { - log.error("Error ", e); - } - } - } - - - /** - * @param args - */ - public static void main(String[] args) { - new ClusterSetup(args).exec(); - } - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/provision/DistributedDeepLearningTrainer.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/provision/DistributedDeepLearningTrainer.java deleted file mode 100755 index 846d8b200..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/provision/DistributedDeepLearningTrainer.java +++ /dev/null @@ -1,31 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.ec2.provision; - -public class DistributedDeepLearningTrainer { - - private DistributedDeepLearningTrainer() {} - - /** - * @param args - */ - public static void main(String[] args) { - ClusterSetup clusterSet = new ClusterSetup(args); - - } - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/provision/HostProvisioner.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/provision/HostProvisioner.java deleted file mode 100755 index c8bbd386a..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/ec2/provision/HostProvisioner.java +++ /dev/null @@ -1,311 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.ec2.provision; - -import com.jcraft.jsch.*; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.BufferedInputStream; -import java.io.File; -import java.io.FileInputStream; -import java.io.InputStream; -import java.util.Arrays; -import java.util.Collection; - -/** - * Meant for uploading files to remote servers - * @author Adam Gibson - * - */ -public class HostProvisioner implements UserInfo { - - private String host; - private JSch jsch; - private String user; - private int port = 22; - private String password; - private static final Logger log = LoggerFactory.getLogger(HostProvisioner.class); - - /** - * - * @param host host to connect to (public facing dns) - * @param user the user to connect with (default root otherwise) - * @param password the password to use if any - * @param port the port to connect to(default 22) - */ - public HostProvisioner(String host, String user, String password, int port) { - super(); - this.host = host; - this.user = user; - this.port = port; - this.password = password; - jsch = new JSch(); - - } - - /** - * Connects to port 22 - * @param host host to connect to (public facing dns) - * @param user the user to connect with (default root otherwise) - * @param password the password to use if any - */ - public HostProvisioner(String host, String user, String password) { - this(host, user, password, 22); - } - - /** - * Connects to port 22 - * @param host host to connect to (public facing dns) - * @param user the user to connect with (default root otherwise) - */ - public HostProvisioner(String host, String user) { - this(host, user, "", 22); - } - - /** - * Connects to port 22, user root, with no password - * @param host host to connect to (public facing dns) - */ - public HostProvisioner(String host) { - this(host, "root", "", 22); - } - - - - public void uploadAndRun(String script, String rootDir) throws Exception { - String remoteName = rootDir.isEmpty() ? new File(script).getName() : rootDir + "/" + new File(script).getName(); - upload(new File(script), remoteName); - - String remoteCommand = remoteName.charAt(0) != '/' ? "./" + remoteName : remoteName; - remoteCommand = "chmod +x " + remoteCommand + " && " + remoteCommand; - runRemoteCommand(remoteCommand); - } - - public void runRemoteCommand(String remoteCommand) throws Exception { - Session session = getSession(); - session.connect(); - ChannelExec channel = (ChannelExec) session.openChannel("exec"); - - - channel.setCommand(remoteCommand); - channel.setErrStream(System.err); - channel.setPty(true); - - channel.setOutputStream(System.out); - channel.connect(); - channel.start(); - InputStream input = channel.getInputStream(); - - //start reading the input from the executed commands on the shell - byte[] tmp = new byte[60000]; - while (true) { - while (input.available() > 0) { - int i = input.read(tmp, 0, tmp.length); - if (i < 0) - break; - log.info(new String(tmp, 0, i)); - } - if (channel.isClosed()) { - log.info("exit-status: " + channel.getExitStatus()); - break; - } - - } - - channel.disconnect(); - session.disconnect(); - - - } - - - private Session getSession() throws Exception { - Session session = jsch.getSession(user, host, port); - session.setUserInfo(this); - return session; - } - - /** - * Creates the directory for the file if necessary - * and uploads the file - * @param from the directory to upload from - * @param to the destination directory on the remote server - * @throws Exception - */ - public void uploadForDeployment(String from, String to) throws Exception { - File fromFile = new File(from); - if (!to.isEmpty() && fromFile.isDirectory()) - mkDir(to); - else - upload(from, to); - - - } - - public void addKeyFile(String keyFile) throws Exception { - jsch.addIdentity(keyFile); - } - - //creates the directory to upload to - private void mkDir(String dir) throws Exception { - Session session = getSession(); - session.connect(); - Channel channel = session.openChannel("sftp"); - channel.connect(); - - ChannelSftp c = (ChannelSftp) channel; - if (!fileExists(dir, c)) - c.mkdir(dir); - c.exit(); - session.disconnect(); - } - - private boolean fileExists(String dir, ChannelSftp channel) { - try { - channel.stat(dir); - return true; - } catch (Exception e) { - return false; - } - } - - - //uploads the file or listed files in a directory - private void upload(String fileOrDir, String uploadRootDir) throws Exception { - if (uploadRootDir.isEmpty()) - uploadRootDir = "."; - File origin = new File(fileOrDir); - - if (fileOrDir.endsWith(".tar") || fileOrDir.endsWith(".tar.gz")) { - upload(new File(fileOrDir), uploadRootDir); - untar(uploadRootDir); - } else if (origin.isFile()) { - upload(new File(fileOrDir), uploadRootDir); - } else { - File[] childFiles = origin.listFiles(); - if (childFiles != null) - upload(Arrays.asList(childFiles), uploadRootDir); - - } - } - - private void untar(String targetRemoteFile) throws Exception { - this.runRemoteCommand("tar xvf " + targetRemoteFile); - } - - private void upload(Collection files, String rootDir) throws Exception { - Session session = getSession(); - session.connect(); - Channel channel = session.openChannel("sftp"); - channel.connect(); - - ChannelSftp c = (ChannelSftp) channel; - for (File f : files) { - if (f.isDirectory()) { - log.warn("Skipping " + f.getName()); - continue; - } - - log.info("Uploading " + f.getName()); - BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f)); - c.put(bis, rootDir + "/" + f.getName()); - bis.close(); - - } - - channel.disconnect(); - session.disconnect(); - - - } - - private void upload(File f, String remoteFile) throws Exception { - Session session = getSession(); - int numRetries = 0; - while (numRetries < 3 && !session.isConnected()) { - try { - session.connect(); - } catch (Exception e) { - numRetries++; - } - } - - try { - Channel channel = session.openChannel("sftp"); - - - channel.connect(); - - ChannelSftp c = (ChannelSftp) channel; - - BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f)); - if (this.fileExists(remoteFile, c)) - if (f.isDirectory()) - c.rmdir(remoteFile); - else - c.rm(remoteFile); - c.put(bis, remoteFile); - bis.close(); - c.exit(); - session.disconnect(); - } catch (Exception e) { - log.info("Session was down...trying again", e); - upload(f, remoteFile); - } - } - - - - @Override - public String getPassphrase() { - return this.password; - } - - - @Override - public String getPassword() { - return this.password; - } - - - @Override - public boolean promptPassphrase(String arg0) { - return true; - } - - - @Override - public boolean promptPassword(String arg0) { - return true; - - } - - - @Override - public boolean promptYesNo(String arg0) { - return true; - - } - - - @Override - public void showMessage(String arg0) { - log.info(arg0); - } - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/emr/EmrConfig.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/emr/EmrConfig.java deleted file mode 100644 index 949738bbd..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/emr/EmrConfig.java +++ /dev/null @@ -1,47 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.emr; - -import com.amazonaws.services.elasticmapreduce.model.Configuration; - -import lombok.*; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; - - -@Data -@AllArgsConstructor(access = AccessLevel.PRIVATE) -@NoArgsConstructor -@Builder -public class EmrConfig { - - protected String classification; - protected Map properties; - protected List configs; - - Configuration toAwsConfig() { - Configuration config = new Configuration().withClassification(classification).withProperties(properties); - List subConfigs = new ArrayList<>(); - for (EmrConfig conf : configs){ - subConfigs.add(conf.toAwsConfig()); - } - return config.withConfigurations(subConfigs); - } - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/emr/SparkEMRClient.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/emr/SparkEMRClient.java deleted file mode 100644 index d179cca09..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/emr/SparkEMRClient.java +++ /dev/null @@ -1,531 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.emr; - -import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce; -import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduceClientBuilder; -import com.amazonaws.services.elasticmapreduce.model.*; -import com.amazonaws.services.s3.AmazonS3ClientBuilder; -import com.amazonaws.services.s3.AmazonS3URI; -import com.amazonaws.services.s3.model.PutObjectRequest; -import lombok.AccessLevel; -import lombok.AllArgsConstructor; -import lombok.Data; -import lombok.NoArgsConstructor; -import lombok.extern.slf4j.Slf4j; -import org.apache.commons.lang3.RandomStringUtils; -import org.nd4j.linalg.function.Function; - -import java.io.File; -import java.util.*; - -/** - * Configuration for a Spark EMR cluster - */ -@Data -@AllArgsConstructor(access = AccessLevel.PRIVATE) -@NoArgsConstructor -@Slf4j -public class SparkEMRClient { - - protected String sparkClusterName = RandomStringUtils.randomAlphanumeric(12); - protected String sparkAwsRegion = "us-east-1"; - protected String sparkEmrRelease = "emr-5.9.0"; - protected String sparkEmrServiceRole = "EMR_DefaultRole"; - protected List sparkEmrConfigs = Collections.emptyList(); - protected String sparkSubnetId = null; - protected List sparkSecurityGroupIds = Collections.emptyList(); - protected int sparkInstanceCount = 1; - protected String sparkInstanceType = "m3.xlarge"; - protected Optional sparkInstanceBidPrice = Optional.empty(); - protected String sparkInstanceRole = "EMR_EC2_DefaultRole"; - protected String sparkS3JarFolder = "changeme"; - protected int sparkTimeoutDurationMinutes = 90; - - //underlying configs - protected AmazonElasticMapReduceClientBuilder sparkEmrClientBuilder; - protected AmazonS3ClientBuilder sparkS3ClientBuilder; - protected JobFlowInstancesConfig sparkJobFlowInstancesConfig; - protected RunJobFlowRequest sparkRunJobFlowRequest; - protected Function sparkS3PutObjectDecorator; - protected Map sparkSubmitConfs; - - - private static ClusterState[] activeClusterStates = new ClusterState[]{ - ClusterState.RUNNING, - ClusterState.STARTING, - ClusterState.WAITING, - ClusterState.BOOTSTRAPPING}; - - private Optional findClusterWithName(AmazonElasticMapReduce emr, String name) { - List csrl = emr.listClusters((new ListClustersRequest()).withClusterStates(activeClusterStates)).getClusters(); - for (ClusterSummary csr : csrl) { - if (csr.getName().equals(name)) return Optional.of(csr); - } - return Optional.empty(); - } - - /** - * Creates the current cluster - */ - public void createCluster() { - AmazonElasticMapReduce emr = sparkEmrClientBuilder.build(); - Optional csr = findClusterWithName(emr, sparkClusterName); - if (csr.isPresent()) { - String msg = String.format("A cluster with name %s and id %s is already deployed", sparkClusterName, csr.get().getId()); - log.error(msg); - throw new IllegalStateException(msg); - } else { - RunJobFlowResult res = emr.runJobFlow(sparkRunJobFlowRequest); - String msg = String.format("Your cluster is launched with name %s and id %s.", sparkClusterName, res.getJobFlowId()); - log.info(msg); - } - - } - - private void logClusters(List csrl) { - if (csrl.isEmpty()) log.info("No cluster found."); - else { - log.info(String.format("%d clusters found.", csrl.size())); - for (ClusterSummary csr : csrl) { - log.info(String.format("Name: %s | Id: %s", csr.getName(), csr.getId())); - } - } - } - - /** - * Lists existing active clusters Names - * - * @return cluster names - */ - public List listActiveClusterNames() { - AmazonElasticMapReduce emr = sparkEmrClientBuilder.build(); - List csrl = - emr.listClusters(new ListClustersRequest().withClusterStates(activeClusterStates)).getClusters(); - logClusters(csrl); - List res = new ArrayList<>(csrl.size()); - for (ClusterSummary csr : csrl) res.add(csr.getName()); - return res; - } - - /** - * List existing active cluster IDs - * - * @return cluster IDs - */ - public List listActiveClusterIds() { - AmazonElasticMapReduce emr = sparkEmrClientBuilder.build(); - List csrl = - emr.listClusters(new ListClustersRequest().withClusterStates(activeClusterStates)).getClusters(); - logClusters(csrl); - List res = new ArrayList<>(csrl.size()); - for (ClusterSummary csr : csrl) res.add(csr.getId()); - return res; - } - - - /** - * Terminates a cluster - */ - public void terminateCluster() { - AmazonElasticMapReduce emr = sparkEmrClientBuilder.build(); - Optional optClusterSum = findClusterWithName(emr, sparkClusterName); - if (!optClusterSum.isPresent()) { - log.error(String.format("The cluster with name %s , requested for deletion, does not exist.", sparkClusterName)); - } else { - String id = optClusterSum.get().getId(); - emr.terminateJobFlows((new TerminateJobFlowsRequest()).withJobFlowIds(id)); - log.info(String.format("The cluster with id %s is terminating.", id)); - } - } - - // The actual job-sumission logic - private void submitJob(AmazonElasticMapReduce emr, String mainClass, List args, Map sparkConfs, File uberJar) throws Exception { - AmazonS3URI s3Jar = new AmazonS3URI(sparkS3JarFolder + "/" + uberJar.getName()); - log.info(String.format("Placing uberJar %s to %s", uberJar.getPath(), s3Jar.toString())); - PutObjectRequest putRequest = sparkS3PutObjectDecorator.apply( - new PutObjectRequest(s3Jar.getBucket(), s3Jar.getKey(), uberJar) - ); - sparkS3ClientBuilder.build().putObject(putRequest); - // The order of these matters - List sparkSubmitArgs = Arrays.asList( - "spark-submit", - "--deploy-mode", - "cluster", - "--class", - mainClass - ); - for (Map.Entry e : sparkConfs.entrySet()) { - sparkSubmitArgs.add(String.format("--conf %s = %s ", e.getKey(), e.getValue())); - } - sparkSubmitArgs.add(s3Jar.toString()); - sparkSubmitArgs.addAll(args); - StepConfig step = new StepConfig() - .withActionOnFailure(ActionOnFailure.CONTINUE) - .withName("Spark step") - .withHadoopJarStep( - new HadoopJarStepConfig() - .withJar("command-runner.jar") - .withArgs(sparkSubmitArgs) - ); - - Optional optCsr = findClusterWithName(emr, sparkClusterName); - if (optCsr.isPresent()) { - ClusterSummary csr = optCsr.get(); - emr.addJobFlowSteps( - new AddJobFlowStepsRequest() - .withJobFlowId(csr.getId()) - .withSteps(step)); - log.info( - String.format("Your job is added to the cluster with id %s.", csr.getId()) - ); - } else { - // If the cluster wasn't started, it's assumed ot be throwaway - List steps = sparkRunJobFlowRequest.getSteps(); - steps.add(step); - RunJobFlowRequest jobFlowRequest = sparkRunJobFlowRequest - .withSteps(steps) - .withInstances(sparkJobFlowInstancesConfig.withKeepJobFlowAliveWhenNoSteps(false)); - - RunJobFlowResult res = emr.runJobFlow(jobFlowRequest); - log.info("Your new cluster's id is %s.", res.getJobFlowId()); - } - - } - - /** - * Submit a Spark Job with a specified main class - */ - public void sparkSubmitJobWithMain(String[] args, String mainClass, File uberJar) throws Exception { - AmazonElasticMapReduce emr = sparkEmrClientBuilder.build(); - submitJob(emr, mainClass, Arrays.asList(args), sparkSubmitConfs, uberJar); - } - - private void checkStatus(AmazonElasticMapReduce emr, String clusterId) throws InterruptedException { - log.info("."); - com.amazonaws.services.elasticmapreduce.model.Cluster dcr = - emr.describeCluster((new DescribeClusterRequest()).withClusterId(clusterId)).getCluster(); - String state = dcr.getStatus().getState(); - long timeOutTime = System.currentTimeMillis() + ((long) sparkTimeoutDurationMinutes * 60 * 1000); - - Boolean activated = Arrays.asList(activeClusterStates).contains(ClusterState.fromValue(state)); - Boolean timedOut = System.currentTimeMillis() > timeOutTime; - if (activated && timedOut) { - emr.terminateJobFlows( - new TerminateJobFlowsRequest().withJobFlowIds(clusterId) - ); - log.error("Timeout. Cluster terminated."); - } else if (!activated) { - Boolean hasAbnormalStep = false; - StepSummary stepS = null; - List steps = emr.listSteps(new ListStepsRequest().withClusterId(clusterId)).getSteps(); - for (StepSummary step : steps) { - if (step.getStatus().getState() != StepState.COMPLETED.toString()) { - hasAbnormalStep = true; - stepS = step; - } - } - if (hasAbnormalStep && stepS != null) - log.error(String.format("Cluster %s terminated with an abnormal step, name %s, id %s", clusterId, stepS.getName(), stepS.getId())); - else - log.info("Cluster %s terminated without error.", clusterId); - } else { - Thread.sleep(5000); - checkStatus(emr, clusterId); - } - } - - /** - * Monitor the cluster and terminates when it times out - */ - public void sparkMonitor() throws InterruptedException { - AmazonElasticMapReduce emr = sparkEmrClientBuilder.build(); - Optional optCsr = findClusterWithName(emr, sparkClusterName); - if (!optCsr.isPresent()) { - log.error(String.format("The cluster with name %s does not exist.", sparkClusterName)); - } else { - ClusterSummary csr = optCsr.get(); - log.info(String.format("found cluster with id %s, starting monitoring", csr.getId())); - checkStatus(emr, csr.getId()); - } - } - - @Data - public static class Builder { - - protected String sparkClusterName = RandomStringUtils.randomAlphanumeric(12); - protected String sparkAwsRegion = "us-east-1"; - protected String sparkEmrRelease = "emr-5.9.0"; - protected String sparkEmrServiceRole = "EMR_DefaultRole"; - protected List sparkEmrConfigs = Collections.emptyList(); - protected String sparkSubNetid = null; - protected List sparkSecurityGroupIds = Collections.emptyList(); - protected int sparkInstanceCount = 1; - protected String sparkInstanceType = "m3.xlarge"; - protected Optional sparkInstanceBidPrice = Optional.empty(); - protected String sparkInstanceRole = "EMR_EC2_DefaultRole"; - protected String sparkS3JarFolder = "changeme"; - protected int sparkTimeoutDurationMinutes = 90; - - protected AmazonElasticMapReduceClientBuilder sparkEmrClientBuilder; - protected AmazonS3ClientBuilder sparkS3ClientBuilder; - protected JobFlowInstancesConfig sparkJobFlowInstancesConfig; - protected RunJobFlowRequest sparkRunJobFlowRequest; - - // This should allow the user to decorate the put call to add metadata to the jar put command, such as security groups, - protected Function sparkS3PutObjectDecorator = new Function() { - @Override - public PutObjectRequest apply(PutObjectRequest putObjectRequest) { - return putObjectRequest; - } - }; - protected Map sparkSubmitConfs; - - - /** - * Defines the EMR cluster's name - * - * @param clusterName the EMR cluster's name - * @return an EMR cluster builder - */ - public Builder clusterName(String clusterName) { - this.sparkClusterName = clusterName; - return this; - } - - /** - * Defines the EMR cluster's region - * See https://docs.aws.amazon.com/general/latest/gr/rande.html - * - * @param region the EMR cluster's region - * @return an EMR cluster builder - */ - public Builder awsRegion(String region) { - this.sparkAwsRegion = region; - return this; - } - - /** - * Defines the EMR release version to be used in this cluster - * uses a release label - * See https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-4.2.0/emr-release-differences.html#emr-release-label - * - * @param releaseLabel the EMR release label - * @return an EM cluster Builder - */ - public Builder emrRelease(String releaseLabel) { - this.sparkEmrRelease = releaseLabel; - return this; - } - - /** - * Defines the IAM role to be assumed by the EMR service - *

- * https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_create_for-service.html - * - * @param serviceRole the service role - * @return an EM cluster Builder - */ - public Builder emrServiceRole(String serviceRole) { - this.sparkEmrServiceRole = serviceRole; - return this; - } - - /** - * A list of configuration parameters to apply to EMR instances. - * - * @param configs the EMR configurations to apply to this cluster - * @return an EMR cluster builder - */ - public Builder emrConfigs(List configs) { - this.sparkEmrConfigs = configs; - return this; - } - - /** - * The id of the EC2 subnet to be used for this Spark EMR service - * see https://docs.aws.amazon.com/AmazonVPC/latest/UserGuide/VPC_Subnets.html - * - * @param id the subnet ID - * @return an EMR cluster builder - */ - public Builder subnetId(String id) { - this.sparkSubNetid = id; - return this; - } - - /** - * The id of additional security groups this deployment should adopt for both master and slaves - * - * @param securityGroups - * @return an EMR cluster builder - */ - public Builder securityGroupIDs(List securityGroups) { - this.sparkSecurityGroupIds = securityGroups; - return this; - } - - /** - * The number of instances this deployment should comprise of - * - * @param count the number of instances for this cluster - * @rturn an EMR cluster buidler - */ - public Builder instanceCount(int count) { - this.sparkInstanceCount = count; - return this; - } - - /** - * The type of instance this cluster should comprise of - * See https://aws.amazon.com/ec2/instance-types/ - * - * @param instanceType the type of instance for this cluster - * @return an EMR cluster builder - */ - public Builder instanceType(String instanceType) { - this.sparkInstanceType = instanceType; - return this; - } - - /** - * The optional bid value for this cluster's spot instances - * see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/how-spot-instances-work.html - * Uses the on-demand market if empty. - * - * @param optBid the Optional bid price for this cluster's instnces - * @return an EMR cluster Builder - */ - public Builder instanceBidPrice(Optional optBid) { - this.sparkInstanceBidPrice = optBid; - return this; - } - - /** - * The EC2 instance role that this cluster's instances should assume - * see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/iam-roles-for-amazon-ec2.html - * - * @param role the intended instance role - * @return an EMR cluster builder - */ - public Builder instanceRole(String role) { - this.sparkInstanceRole = role; - return this; - } - - /** - * the S3 folder in which to find the application jar - * - * @param jarfolder the S3 folder in which to find a jar - * @return an EMR cluster builder - */ - public Builder s3JarFolder(String jarfolder) { - this.sparkS3JarFolder = jarfolder; - return this; - } - - /** - * The timeout duration for this Spark EMR cluster, in minutes - * - * @param timeoutMinutes - * @return an EMR cluster builder - */ - public Builder sparkTimeOutDurationMinutes(int timeoutMinutes) { - this.sparkTimeoutDurationMinutes = timeoutMinutes; - return this; - } - - /** - * Creates an EMR Spark cluster deployment - * - * @return a SparkEMRClient - */ - public SparkEMRClient build() { - this.sparkEmrClientBuilder = AmazonElasticMapReduceClientBuilder.standard().withRegion(sparkAwsRegion); - this.sparkS3ClientBuilder = AmazonS3ClientBuilder.standard().withRegion(sparkAwsRegion); - // note this will be kept alive without steps, an arbitrary choice to avoid rapid test-teardown-restart cycles - this.sparkJobFlowInstancesConfig = (new JobFlowInstancesConfig()).withKeepJobFlowAliveWhenNoSteps(true); - if (this.sparkSubNetid != null) - this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withEc2SubnetId(this.sparkSubNetid); - if (!this.sparkSecurityGroupIds.isEmpty()) { - this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withAdditionalMasterSecurityGroups(this.sparkSecurityGroupIds); - this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withAdditionalSlaveSecurityGroups(this.sparkSecurityGroupIds); - } - - InstanceGroupConfig masterConfig = - (new InstanceGroupConfig()).withInstanceCount(1).withInstanceRole(InstanceRoleType.MASTER).withInstanceType(sparkInstanceType); - if (sparkInstanceBidPrice.isPresent()) { - masterConfig = masterConfig.withMarket(MarketType.SPOT).withBidPrice(sparkInstanceBidPrice.get().toString()); - } else { - masterConfig = masterConfig.withMarket(MarketType.ON_DEMAND); - } - - int slaveCount = sparkInstanceCount - 1; - InstanceGroupConfig slaveConfig = - (new InstanceGroupConfig()).withInstanceCount(slaveCount).withInstanceRole(InstanceRoleType.CORE).withInstanceRole(sparkInstanceType); - if (sparkInstanceBidPrice.isPresent()) { - slaveConfig = slaveConfig.withMarket(MarketType.SPOT).withBidPrice(sparkInstanceBidPrice.get().toString()); - } else { - slaveConfig = slaveConfig.withMarket(MarketType.ON_DEMAND); - } - if (slaveCount > 0) { - this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withInstanceGroups(Arrays.asList(masterConfig, slaveConfig)); - } else { - this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withInstanceGroups(slaveConfig); - } - - this.sparkRunJobFlowRequest = new RunJobFlowRequest(); - if (!sparkEmrConfigs.isEmpty()) { - List emrConfigs = new ArrayList<>(); - for (EmrConfig config : sparkEmrConfigs) { - emrConfigs.add(config.toAwsConfig()); - } - this.sparkRunJobFlowRequest = this.sparkRunJobFlowRequest.withConfigurations(emrConfigs); - } - this.sparkRunJobFlowRequest = - this.sparkRunJobFlowRequest.withName(sparkClusterName).withApplications((new Application()).withName("Spark")) - .withReleaseLabel(sparkEmrRelease) - .withServiceRole(sparkEmrServiceRole) - .withJobFlowRole(sparkInstanceRole) - .withInstances(this.sparkJobFlowInstancesConfig); - - return new SparkEMRClient( - this.sparkClusterName, - this.sparkAwsRegion, - this.sparkEmrRelease, - this.sparkEmrServiceRole, - this.sparkEmrConfigs, - this.sparkSubNetid, - this.sparkSecurityGroupIds, - this.sparkInstanceCount, - this.sparkInstanceType, - this.sparkInstanceBidPrice, - this.sparkInstanceRole, - this.sparkS3JarFolder, - this.sparkTimeoutDurationMinutes, - this.sparkEmrClientBuilder, - this.sparkS3ClientBuilder, - this.sparkJobFlowInstancesConfig, - this.sparkRunJobFlowRequest, - this.sparkS3PutObjectDecorator, - this.sparkSubmitConfs - ); - } - - } - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/BaseS3.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/BaseS3.java deleted file mode 100755 index 13175505b..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/BaseS3.java +++ /dev/null @@ -1,117 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.s3; - -import com.amazonaws.auth.AWSCredentials; -import com.amazonaws.auth.BasicAWSCredentials; -import com.amazonaws.auth.PropertiesCredentials; -import com.amazonaws.services.ec2.AmazonEC2; -import com.amazonaws.services.ec2.AmazonEC2Client; -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3Client; - -import java.io.File; -import java.io.InputStream; - - -/** - * The S3 Credentials works via discovering the credentials - * from the system properties (passed in via -D or System wide) - * If you invoke the JVM with -Dorg.deeplearning4j.aws.accessKey=YOUR_ACCESS_KEY - * and -Dorg.deeplearning4j.aws.accessSecret=YOUR_SECRET_KEY - * this will pick up the credentials from there, otherwise it will also attempt to look in - * the system environment for the following variables: - * - * - * AWS_ACCESS_KEY_ID - * AWS_SECRET_ACCESS_KEY - * @author Adam Gibson - * - */ -public abstract class BaseS3 { - - - /** - * - */ - protected static final long serialVersionUID = -2280107690193651289L; - protected String accessKey; - protected String secretKey; - protected AWSCredentials creds; - public final static String ACCESS_KEY = "org.deeplearning4j.aws.accessKey"; - public final static String ACCESS_SECRET = "org.deeplearning4j.aws.accessSecret"; - public final static String AWS_ACCESS_KEY = "AWS_ACCESS_KEY"; //"AWS_ACCESS_KEY_ID"; - public final static String AWS_SECRET_KEY = "AWS_SECRET_KEY"; //"AWS_SECRET_ACCESS_KEY"; - - - protected void findCreds() { - if (System.getProperty(ACCESS_KEY) != null && System.getProperty(ACCESS_SECRET) != null) { - accessKey = System.getProperty(ACCESS_KEY); - secretKey = System.getProperty(ACCESS_SECRET); - } - - else if (System.getenv(AWS_ACCESS_KEY) != null && System.getenv(AWS_SECRET_KEY) != null) { - accessKey = System.getenv(AWS_ACCESS_KEY); - secretKey = System.getenv(AWS_SECRET_KEY); - } - } - - public BaseS3() { - findCreds(); - if (accessKey != null && secretKey != null) - creds = new BasicAWSCredentials(accessKey, secretKey); - if (creds == null) - throw new IllegalStateException("Unable to find ec2 credentials"); - } - - public BaseS3(File file) throws Exception { - if (accessKey != null && secretKey != null) - creds = new BasicAWSCredentials(accessKey, secretKey); - else - creds = new PropertiesCredentials(file); - - - } - - public BaseS3(InputStream is) throws Exception { - if (accessKey != null && secretKey != null) - creds = new BasicAWSCredentials(accessKey, secretKey); - else - creds = new PropertiesCredentials(is); - - - } - - public AWSCredentials getCreds() { - return creds; - } - - public void setCreds(AWSCredentials creds) { - this.creds = creds; - } - - public AmazonS3 getClient() { - return new AmazonS3Client(creds); - } - - public AmazonEC2 getEc2() { - - return new AmazonEC2Client(creds); - } - - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/BaseS3DataSetIterator.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/BaseS3DataSetIterator.java deleted file mode 100755 index 5532f79d2..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/BaseS3DataSetIterator.java +++ /dev/null @@ -1,83 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.s3.reader; - -import org.nd4j.linalg.dataset.api.iterator.DataSetIterator; - -import java.io.InputStream; -import java.util.Iterator; -import java.util.List; - -/** - * baseline data applyTransformToDestination iterator for - * @author Adam Gibson - * - */ -public abstract class BaseS3DataSetIterator implements DataSetIterator { - - /** - * - */ - private static final long serialVersionUID = 885205002006822431L; - private S3Downloader downloader; - private List buckets; - private int currBucket; - private Iterator currIterator; - - public BaseS3DataSetIterator() { - downloader = new S3Downloader(); - buckets = downloader.buckets(); - currBucket = 0; - currIterator = downloader.iterateBucket(buckets.get(currBucket)); - } - - - - public InputStream nextObject() { - if (currIterator.hasNext()) - return currIterator.next(); - else if (currBucket < buckets.size()) { - currBucket++; - currIterator = downloader.iterateBucket(buckets.get(currBucket)); - return currIterator.next(); - } - - return null; - } - - - - @Override - public boolean hasNext() { - return currBucket < buckets.size() && currIterator.hasNext(); - } - - - - public String currBucket() { - return buckets.get(currBucket); - } - - - - public void nextBucket() { - currBucket++; - } - - - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/BucketIterator.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/BucketIterator.java deleted file mode 100755 index 228fc5cfc..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/BucketIterator.java +++ /dev/null @@ -1,93 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.s3.reader; - -import com.amazonaws.services.s3.model.ObjectListing; -import com.amazonaws.services.s3.model.S3ObjectSummary; - -import java.io.InputStream; -import java.util.Iterator; -import java.util.List; - -/** - * Iterator over individual S3 bucket - * @author Adam Gibson - * - */ -public class BucketIterator implements Iterator { - - private S3Downloader s3; - private String bucket; - private ObjectListing currList; - private List currObjects; - private int currObject; - - - - public BucketIterator(String bucket) { - this(bucket, null); - - } - - - public BucketIterator(String bucket, S3Downloader s3) { - this.bucket = bucket; - - if (s3 == null) - this.s3 = new S3Downloader(); - else - this.s3 = s3; - currList = this.s3.listObjects(bucket); - currObjects = currList.getObjectSummaries(); - - } - - - - @Override - public boolean hasNext() { - return currObject < currObjects.size(); - } - - @Override - public InputStream next() { - if (currObject < currObjects.size()) { - InputStream ret = s3.objectForKey(bucket, currObjects.get(currObject).getKey()); - currObject++; - return ret; - } else if (currList.isTruncated()) { - currList = s3.nextList(currList); - currObjects = currList.getObjectSummaries(); - currObject = 0; - - InputStream ret = s3.objectForKey(bucket, currObjects.get(currObject).getKey()); - - currObject++; - return ret; - } - - - throw new IllegalStateException("Indeterminate state"); - } - - @Override - public void remove() { - throw new UnsupportedOperationException(); - } - - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/BucketKeyListener.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/BucketKeyListener.java deleted file mode 100755 index 95c356c78..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/BucketKeyListener.java +++ /dev/null @@ -1,38 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.s3.reader; - -import com.amazonaws.services.s3.AmazonS3; - -/** - * When paginating through a result applyTransformToDestination, - * allows the user to react to a bucket result being found - * @author Adam Gibson - * - */ -public interface BucketKeyListener { - - /** - * - * @param s3 an s3 client - * @param bucket the bucket being iterated on - * @param key the current key - */ - void onKey(AmazonS3 s3, String bucket, String key); - - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/S3Downloader.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/S3Downloader.java deleted file mode 100755 index 980a3f3e9..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/reader/S3Downloader.java +++ /dev/null @@ -1,178 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.s3.reader; - -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.model.*; -import com.amazonaws.services.s3.transfer.MultipleFileDownload; -import com.amazonaws.services.s3.transfer.TransferManager; -import org.apache.commons.io.IOUtils; -import org.deeplearning4j.aws.s3.BaseS3; - -import java.io.*; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; - -/** - * Downloads files from S3 - * @author Adam Gibson - * - */ -public class S3Downloader extends BaseS3 { - - - /** - * Return the keys for a bucket - * @param bucket the bucket to get the keys for - * @return the bucket's keys - */ - public List keysForBucket(String bucket) { - AmazonS3 s3 = getClient(); - List ret = new ArrayList<>(); - ListObjectsRequest listObjectsRequest = new ListObjectsRequest().withBucketName(bucket); - ObjectListing objectListing; - - do { - objectListing = s3.listObjects(listObjectsRequest); - for (S3ObjectSummary objectSummary : objectListing.getObjectSummaries()) { - ret.add(objectSummary.getKey()); - } - listObjectsRequest.setMarker(objectListing.getNextMarker()); - } while (objectListing.isTruncated()); - - return ret; - } - - /** - * Returns the list of buckets in s3 - * @return the list of buckets - */ - public List buckets() { - List ret = new ArrayList<>(); - AmazonS3 s3 = getClient(); - List buckets = s3.listBuckets(); - for (Bucket b : buckets) - ret.add(b.getName()); - return ret; - } - - /** - * Iterate over individual buckets. - * Returns input streams to each object. - * It is your responsibility to close the input streams - * @param bucket the bucket to iterate over - * @return an iterator over the objects in an s3 bucket - */ - public Iterator iterateBucket(String bucket) { - return new BucketIterator(bucket, this); - } - - /** - * Iterator style one list at a time - * @param list the list to get the next batch for - * @return the next batch of objects or null if - * none are left - */ - public ObjectListing nextList(ObjectListing list) { - AmazonS3 s3 = getClient(); - if (list.isTruncated()) - return s3.listNextBatchOfObjects(list); - return null; - } - - /** - * Simple way of retrieving the listings for a bucket - * @param bucket the bucket to retrieve listings for - * @return the object listing for this bucket - */ - public ObjectListing listObjects(String bucket) { - AmazonS3 s3 = getClient(); - ObjectListing list = s3.listObjects(bucket); - return list; - } - - /** - * Paginates through a bucket's keys invoking the listener - * at each key - * @param bucket the bucket to iterate - * @param listener the listener - */ - public void paginate(String bucket, BucketKeyListener listener) { - AmazonS3 s3 = getClient(); - ObjectListing list = s3.listObjects(bucket); - for (S3ObjectSummary summary : list.getObjectSummaries()) { - if (listener != null) - listener.onKey(s3, bucket, summary.getKey()); - } - - while (list.isTruncated()) { - list = s3.listNextBatchOfObjects(list); - for (S3ObjectSummary summary : list.getObjectSummaries()) { - if (listener != null) - listener.onKey(s3, bucket, summary.getKey()); - } - } - - - } - - - /** - * Returns an input stream for the given bucket and key - * @param bucket the bucket to retrieve from - * @param key the key of the objec t - * @return an input stream to the object - */ - public InputStream objectForKey(String bucket, String key) { - AmazonS3 s3 = getClient(); - S3Object obj = s3.getObject(bucket, key); - InputStream is = obj.getObjectContent(); - return is; - } - - - public void download(String bucket, String key, File to) throws IOException { - AmazonS3 s3 = getClient(); - S3Object obj = s3.getObject(bucket, key); - InputStream is = obj.getObjectContent(); - BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(to)); - IOUtils.copy(is, bos); - bos.close(); - is.close(); - obj.close(); - } - - public void download(String bucket, String key, OutputStream to) throws IOException { - AmazonS3 s3 = getClient(); - S3Object obj = s3.getObject(bucket, key); - InputStream is = obj.getObjectContent(); - BufferedOutputStream bos = new BufferedOutputStream(to); - - IOUtils.copy(is, bos); - bos.close(); - is.close(); - obj.close(); - } - - public MultipleFileDownload downloadFolder(String bucketName, String keyPrefix, File folderPath) { - TransferManager transfer = new TransferManager(getClient()); - return transfer.downloadDirectory(bucketName, keyPrefix, folderPath); - } - - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/uploader/S3Uploader.java b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/uploader/S3Uploader.java deleted file mode 100755 index eacc71386..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/java/org/deeplearning4j/aws/s3/uploader/S3Uploader.java +++ /dev/null @@ -1,172 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.aws.s3.uploader; - -import com.amazonaws.services.s3.AmazonS3; -import com.amazonaws.services.s3.AmazonS3Client; -import com.amazonaws.services.s3.model.*; -import com.amazonaws.services.s3.transfer.MultipleFileUpload; -import com.amazonaws.services.s3.transfer.TransferManager; -import org.deeplearning4j.aws.s3.BaseS3; - -import java.io.File; -import java.util.ArrayList; -import java.util.List; - -/** - * Uploads files to S3 - * - * @see {@link BaseS3} - * @author Adam Gibson - * - */ -public class S3Uploader extends BaseS3 { - - - /** - * Multi part upload for big files - * @param file the file to upload - * @param bucketName the bucket name to upload - */ - public void multiPartUpload(File file, String bucketName) { - AmazonS3 client = new AmazonS3Client(creds); - bucketName = ensureValidBucketName(bucketName); - - List buckets = client.listBuckets(); - for (Bucket b : buckets) - if (b.getName().equals(bucketName)) { - doMultiPart(client, bucketName, file); - return; - } - - //bucket didn't exist: create it - client.createBucket(bucketName); - doMultiPart(client, bucketName, file); - } - - /** - * Upload the file to the bucket. - * Will create the bucket if it hasn't already been created - * @param file the file to upload - * @param bucketName the name of the bucket - */ - public void upload(File file, String bucketName) { - AmazonS3 client = new AmazonS3Client(creds); - bucketName = ensureValidBucketName(bucketName); - - List buckets = client.listBuckets(); - for (Bucket b : buckets) - if (b.getName().equals(bucketName)) { - client.putObject(bucketName, file.getName(), file); - return; - } - - //bucket didn't exist: create it - client.createBucket(bucketName); - client.putObject(bucketName, file.getName(), file); - - } - - private void doMultiPart(AmazonS3 s3Client, String bucketName, File file) { - // Create a list of UploadPartResponse objects. You get one of these - // for each part upload. - List partETags = new ArrayList<>(); - - // Step 1: Initialize. - InitiateMultipartUploadRequest initRequest = new InitiateMultipartUploadRequest(bucketName, file.getName()); - InitiateMultipartUploadResult initResponse = s3Client.initiateMultipartUpload(initRequest); - - long contentLength = file.length(); - long partSize = 5242880; // Set part size to 5 MB. - - try { - // Step 2: Upload parts. - long filePosition = 0; - for (int i = 1; filePosition < contentLength; i++) { - // Last part can be less than 5 MB. Adjust part size. - partSize = Math.min(partSize, (contentLength - filePosition)); - - // Create request to upload a part. - UploadPartRequest uploadRequest = new UploadPartRequest().withBucketName(bucketName) - .withKey(file.getName()).withUploadId(initResponse.getUploadId()).withPartNumber(i) - .withFileOffset(filePosition).withFile(file).withPartSize(partSize); - - // Upload part and add response to our list. - partETags.add(s3Client.uploadPart(uploadRequest).getPartETag()); - - filePosition += partSize; - } - - // Step 3: Complete. - CompleteMultipartUploadRequest compRequest = new CompleteMultipartUploadRequest(bucketName, file.getName(), - initResponse.getUploadId(), partETags); - - s3Client.completeMultipartUpload(compRequest); - } catch (Exception e) { - s3Client.abortMultipartUpload( - new AbortMultipartUploadRequest(bucketName, file.getName(), initResponse.getUploadId())); - } - } - - private String ensureValidBucketName(String bucketName) { - String formatted = bucketName.replaceAll("\\s+", "_"); - int length = bucketName.length(); - if (length >= 62) - length = 62; - formatted = formatted.substring(0, length); - formatted = formatted.replace(".", "d"); - formatted = formatted.toLowerCase(); - if (formatted.endsWith("-")) - formatted = formatted.substring(0, length - 1); - - return formatted; - } - - public void upload(File file, String name, String bucketName) { - AmazonS3 client = getClient(); - bucketName = ensureValidBucketName(bucketName); - List buckets = client.listBuckets(); - // ObjectMetadata med = new ObjectMetadata(); - // med.setContentLength(fileLength); - for (Bucket b : buckets) - if (b.getName().equals(bucketName)) { - //client.putObject(bucketName, name, is, med); - client.putObject(new PutObjectRequest(bucketName, name, file)); - return; - } - - //bucket didn't exist: createComplex it - client.createBucket(bucketName); - //client.putObject(bucketName, name, is, med); - client.putObject(new PutObjectRequest(bucketName, name, file)); - } - - - public MultipleFileUpload uploadFolder(String bucketName, String keyPrefix, File folderPath, - boolean includeSubDir) { - TransferManager transfer = new TransferManager(getClient()); - return transfer.uploadDirectory(bucketName, keyPrefix, folderPath, includeSubDir); - } - - public MultipleFileUpload uploadFileList(String bucketName, File folderPath, List fileList, - String keyPrefix) { - TransferManager transfer = new TransferManager(getClient()); - return transfer.uploadFileList(bucketName, keyPrefix, folderPath, fileList); - } - - -} diff --git a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/resources/provision-master.sh b/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/resources/provision-master.sh deleted file mode 100755 index 7f7285bb5..000000000 --- a/deeplearning4j/deeplearning4j-scaleout/deeplearning4j-aws/src/main/resources/provision-master.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/bin/bash - -################################################################################ -# Copyright (c) 2015-2018 Skymind, Inc. -# -# This program and the accompanying materials are made available under the -# terms of the Apache License, Version 2.0 which is available at -# https://www.apache.org/licenses/LICENSE-2.0. -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT -# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the -# License for the specific language governing permissions and limitations -# under the License. -# -# SPDX-License-Identifier: Apache-2.0 -################################################################################ - -sudo yum -y install blas java-1.7.0-openjdk-devel - -if [ ! -f "/opt/dl4j" ]; -then - sudo mkdir /opt/dl4j - sudo yum -y install git - git clone https://github.com/agibsonccc/java-deeplearning - - wget http://www.trieuvan.com/apache/maven/maven-3/3.1.1/binaries/apache-maven-3.1.1-bin.tar.gz - sudo mv apache-maven-3.1.1-bin.tar.gz /opt - cd /opt && sudo tar xvf apache-maven-3.1.1-bin.tar.gz && sudo mv apache-maven-3.1.1 /opt/mvn - cd /home/ec2-user/java-deeplearning/ && /opt/mvn/bin/mvn -DskipTests clean install - echo "Printing distribution" - ls /home/ec2-user/java-deeplearning/deeplearning4j-distribution/target - echo "Before moving distribution" - sudo mv deeplearning4j-distribution/target/deeplearning4j-dist.tar.gz /opt - echo "Moving distribution to opt directory..." - echo "Moving in to opt directory" - cd /opt - - sudo tar xzvf deeplearning4j-dist.tar.gz - #sudo mkdir /opt/dl4j - echo "Moving jars in to /opt/dl4j/..." - sudo mv *.jar /opt/dl4j -fi - - diff --git a/deeplearning4j/deeplearning4j-scaleout/pom.xml b/deeplearning4j/deeplearning4j-scaleout/pom.xml index 2c192cc10..539aa3ef7 100644 --- a/deeplearning4j/deeplearning4j-scaleout/pom.xml +++ b/deeplearning4j/deeplearning4j-scaleout/pom.xml @@ -28,7 +28,6 @@ DeepLearning4j-scaleout-parent - deeplearning4j-aws spark deeplearning4j-scaleout-parallelwrapper deeplearning4j-scaleout-parallelwrapper-parameter-server diff --git a/deeplearning4j/deeplearning4j-util/pom.xml b/deeplearning4j/deeplearning4j-util/pom.xml deleted file mode 100644 index b49239a9e..000000000 --- a/deeplearning4j/deeplearning4j-util/pom.xml +++ /dev/null @@ -1,58 +0,0 @@ - - - - - deeplearning4j-parent - org.deeplearning4j - 1.0.0-SNAPSHOT - - 4.0.0 - - deeplearning4j-util - jar - - deeplearning4j-util - http://maven.apache.org - - - UTF-8 - - - - - org.nd4j - nd4j-api - ${nd4j.version} - - - org.nd4j - nd4j-common - ${nd4j.version} - - - - - - - test-nd4j-native - - - test-nd4j-cuda-10.2 - - - diff --git a/deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/DiskBasedQueue.java b/deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/DiskBasedQueue.java deleted file mode 100644 index 69999abf7..000000000 --- a/deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/DiskBasedQueue.java +++ /dev/null @@ -1,193 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.util; - -import org.apache.commons.io.FileUtils; -import org.nd4j.linalg.util.SerializationUtils; - -import java.io.File; -import java.io.IOException; -import java.io.Serializable; -import java.util.Collection; -import java.util.Iterator; -import java.util.Queue; -import java.util.UUID; -import java.util.concurrent.ConcurrentLinkedDeque; -import java.util.concurrent.Executors; -import java.util.concurrent.atomic.AtomicBoolean; - -/** - * Naive disk based queue for storing items on disk. - * Only meant for poll and adding items. - * @author Adam Gibson - */ -public class DiskBasedQueue implements Queue, Serializable { - - private File dir; - private Queue paths = new ConcurrentLinkedDeque<>(); - private AtomicBoolean running = new AtomicBoolean(true); - private Queue save = new ConcurrentLinkedDeque<>(); - - public DiskBasedQueue() { - this(".queue"); - } - - public DiskBasedQueue(String path) { - this(new File(path)); - - } - - public DiskBasedQueue(File dir) { - this.dir = dir; - if (!dir.exists() && dir.isDirectory()) { - throw new IllegalArgumentException("Illegal queue: must be a directory"); - } - - if (!dir.exists()) - dir.mkdirs(); - if (dir.listFiles() != null && dir.listFiles().length > 1) - try { - FileUtils.deleteDirectory(dir); - } catch (IOException e) { - e.printStackTrace(); - } - - - dir.mkdir(); - - Thread t = Executors.defaultThreadFactory().newThread(new Runnable() { - @Override - public void run() { - while (running.get()) { - while (!save.isEmpty()) - addAndSave(save.poll()); - - ThreadUtils.uncheckedSleep(1000); - } - } - }); - t.setName("DiskBasedQueueSaver"); - t.setDaemon(true); - t.start(); - } - - @Override - public int size() { - return paths.size(); - } - - @Override - public boolean isEmpty() { - return paths.isEmpty(); - } - - @Override - public boolean contains(Object o) { - throw new UnsupportedOperationException(); - - } - - @Override - public Iterator iterator() { - throw new UnsupportedOperationException(); - - } - - @Override - public Object[] toArray() { - throw new UnsupportedOperationException(); - - } - - @Override - public T[] toArray(T[] a) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean add(E e) { - save.add(e); - return true; - } - - @Override - public boolean remove(Object o) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean containsAll(Collection c) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean addAll(Collection c) { - for (E e : c) - addAndSave(e); - return true; - } - - @Override - public boolean removeAll(Collection c) { - throw new UnsupportedOperationException(); - } - - @Override - public boolean retainAll(Collection c) { - throw new UnsupportedOperationException(); - } - - @Override - public void clear() { - throw new UnsupportedOperationException(); - } - - @Override - public boolean offer(E e) { - throw new UnsupportedOperationException(); - } - - @Override - public E remove() { - throw new UnsupportedOperationException(); - } - - @Override - public E poll() { - String path = paths.poll(); - E ret = SerializationUtils.readObject(new File(path)); - File item = new File(path); - item.delete(); - return ret; - } - - @Override - public E element() { - throw new UnsupportedOperationException(); - } - - @Override - public E peek() { - throw new UnsupportedOperationException(); - } - - private void addAndSave(E e) { - File path = new File(dir, UUID.randomUUID().toString()); - SerializationUtils.saveObject(e, path); - paths.add(path.getAbsolutePath()); - } -} diff --git a/deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/Dl4jReflection.java b/deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/Dl4jReflection.java deleted file mode 100755 index b620a8c21..000000000 --- a/deeplearning4j/deeplearning4j-util/src/main/java/org/deeplearning4j/util/Dl4jReflection.java +++ /dev/null @@ -1,132 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.deeplearning4j.util; - - -import java.lang.reflect.Constructor; -import java.lang.reflect.Field; -import java.lang.reflect.Modifier; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; -import java.util.Properties; - -/** - * @author Adam Gibson - */ -public class Dl4jReflection { - private Dl4jReflection() {} - - /** - * Gets the empty constructor from a class - * @param clazz the class to get the constructor from - * @return the empty constructor for the class - */ - public static Constructor getEmptyConstructor(Class clazz) { - Constructor c = clazz.getDeclaredConstructors()[0]; - for (int i = 0; i < clazz.getDeclaredConstructors().length; i++) { - if (clazz.getDeclaredConstructors()[i].getParameterTypes().length < 1) { - c = clazz.getDeclaredConstructors()[i]; - break; - } - } - - return c; - } - - - public static Field[] getAllFields(Class clazz) { - // Keep backing up the inheritance hierarchy. - Class targetClass = clazz; - List fields = new ArrayList<>(); - - do { - fields.addAll(Arrays.asList(targetClass.getDeclaredFields())); - targetClass = targetClass.getSuperclass(); - } while (targetClass != null && targetClass != Object.class); - - return fields.toArray(new Field[fields.size()]); - } - - /** - * Sets the properties of the given object - * @param obj the object o set - * @param props the properties to set - */ - public static void setProperties(Object obj, Properties props) throws Exception { - for (Field field : obj.getClass().getDeclaredFields()) { - field.setAccessible(true); - if (props.containsKey(field.getName())) { - set(field, obj, props.getProperty(field.getName())); - } - - } - } - - /* sets a field with a fairly basic strategy */ - private static void set(Field field, Object obj, String value) throws Exception { - Class clazz = field.getType(); - field.setAccessible(true); - if (clazz.equals(Double.class) || clazz.equals(double.class)) { - double val = Double.valueOf(value); - field.set(obj, val); - } else if (clazz.equals(String.class)) { - field.set(obj, value); - } else if (clazz.equals(Integer.class) || clazz.equals(int.class)) { - int val = Integer.parseInt(value); - field.set(obj, val); - } else if (clazz.equals(Float.class) || clazz.equals(float.class)) { - float f = Float.parseFloat(value); - field.set(obj, f); - } - } - - - /** - * Get fields as properties - * @param obj the object to get fields for - * @param clazzes the classes to use for reflection and properties. - * T - * @return the fields as properties - */ - public static Properties getFieldsAsProperties(Object obj, Class[] clazzes) throws Exception { - Properties props = new Properties(); - for (Field field : obj.getClass().getDeclaredFields()) { - if (Modifier.isStatic(field.getModifiers())) - continue; - field.setAccessible(true); - Class type = field.getType(); - if (clazzes == null || contains(type, clazzes)) { - Object val = field.get(obj); - if (val != null) - props.put(field.getName(), val.toString()); - - } - } - - return props; - } - - - private static boolean contains(Class test, Class[] arr) { - for (Class c : arr) - if (c.equals(test)) - return true; - return false; - } - -} diff --git a/deeplearning4j/dl4j-perf/pom.xml b/deeplearning4j/dl4j-perf/pom.xml deleted file mode 100644 index 1e1dd07f9..000000000 --- a/deeplearning4j/dl4j-perf/pom.xml +++ /dev/null @@ -1,135 +0,0 @@ - - - - - - - deeplearning4j-parent - org.deeplearning4j - 1.0.0-SNAPSHOT - - 4.0.0 - - dl4j-perf - - dl4j-perf - - - UTF-8 - 1.7 - 1.7 - - - - - org.slf4j - slf4j-api - - - com.github.oshi - oshi-json - ${oshi.version} - - - org.deeplearning4j - deeplearning4j-nn - ${project.version} - - - com.github.oshi - oshi-core - ${oshi.version} - - - - junit - junit - - - - org.projectlombok - lombok - ${lombok.version} - provided - - - - ch.qos.logback - logback-classic - test - - - org.deeplearning4j - deeplearning4j-datasets - ${project.version} - test - - - - org.deeplearning4j - deeplearning4j-common-tests - ${project.version} - test - - - - - - - - maven-clean-plugin - 3.0.0 - - - - maven-resources-plugin - 3.0.2 - - - maven-compiler-plugin - 3.7.0 - - - maven-surefire-plugin - 2.20.1 - - - maven-jar-plugin - 3.0.2 - - - maven-install-plugin - 2.5.2 - - - maven-deploy-plugin - 2.8.2 - - - - - - - - test-nd4j-native - - - test-nd4j-cuda-10.2 - - - diff --git a/deeplearning4j/dl4j-perf/src/test/resources/logback-test.xml b/deeplearning4j/dl4j-perf/src/test/resources/logback-test.xml deleted file mode 100644 index c6f89b60a..000000000 --- a/deeplearning4j/dl4j-perf/src/test/resources/logback-test.xml +++ /dev/null @@ -1,50 +0,0 @@ - - - - - - logs/application.log - - %date - [%level] - from %logger in %thread - %n%message%n%xException%n - - - - - - %logger{15} - %message%n%xException{5} - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/deeplearning4j/pom.xml b/deeplearning4j/pom.xml index 15be211de..bcf633e3a 100644 --- a/deeplearning4j/pom.xml +++ b/deeplearning4j/pom.xml @@ -140,8 +140,6 @@ deeplearning4j-nearestneighbors-parent deeplearning4j-data deeplearning4j-manifold - deeplearning4j-util - dl4j-perf dl4j-integration-tests deeplearning4j-common deeplearning4j-remote diff --git a/libnd4j/tests_cpu/run_tests.sh b/libnd4j/tests_cpu/run_tests.sh index 2932827d4..9b1271df6 100755 --- a/libnd4j/tests_cpu/run_tests.sh +++ b/libnd4j/tests_cpu/run_tests.sh @@ -39,6 +39,7 @@ do done CHIP="${CHIP:-cpu}" +export GTEST_OUTPUT="xml:../target/surefire-reports/TEST-${CHIP}-results.xml" # On Mac, make sure it can find libraries for GCC export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/lib/gcc/6/:/usr/local/lib/gcc/5/ @@ -47,8 +48,9 @@ export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/ if [ -n "$BUILD_PATH" ]; then if which cygpath; then BUILD_PATH=$(cygpath -p $BUILD_PATH) + export GTEST_OUTPUT="xml:'..\target\surefire-reports\TEST-${CHIP}-results.xml'" fi export PATH="$PATH:$BUILD_PATH" fi -../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests --gtest_output="xml:../target/surefire-reports/TEST-${CHIP}-results.xml" +../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/pom.xml b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/pom.xml index b4a374baf..9a42f6bd0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/pom.xml +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/pom.xml @@ -217,16 +217,6 @@ commons-net ${commons-net.version} - - org.nd4j - nd4j-buffer - ${project.version} - - - org.nd4j - nd4j-context - ${project.version} - net.ericaro neoitertools @@ -238,6 +228,16 @@ + + org.nd4j + nd4j-common + ${project.version} + + + org.bytedeco + javacpp + ${javacpp.version} + diff --git a/nd4j/nd4j-context/src/main/java/org/nd4j/context/Nd4jContext.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/context/Nd4jContext.java similarity index 100% rename from nd4j/nd4j-context/src/main/java/org/nd4j/context/Nd4jContext.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/context/Nd4jContext.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMax.java index cec374acb..35181bc23 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMax.java @@ -81,11 +81,6 @@ public class BroadcastMax extends BaseBroadcastOp { throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); } - @Override - public String tensorflowName() { - return "Max"; - } - @Override public List doDiff(List f1) { return null; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMin.java index d1d70e8b4..c8cac0b6c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMin.java @@ -81,11 +81,6 @@ public class BroadcastMin extends BaseBroadcastOp { throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); } - @Override - public String tensorflowName() { - return "Min"; - } - @Override public List doDiff(List f1) { return null; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMulOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMulOp.java index ccbc17700..406b3f6f2 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMulOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastMulOp.java @@ -75,11 +75,6 @@ public class BroadcastMulOp extends BaseBroadcastOp { throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); } - @Override - public String tensorflowName() { - return "Mul"; - } - @Override public List doDiff(List f1) { return null; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastSubOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastSubOp.java index e060db4b6..f4e93ea22 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastSubOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/broadcast/BroadcastSubOp.java @@ -75,11 +75,6 @@ public class BroadcastSubOp extends BaseBroadcastOp { throw new NoOpNameFoundException("No onnx op opName found for " + opName()); } - @Override - public String tensorflowName(){ - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - @Override public List doDiff(List f1) { return null; diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java index 8526a251b..008a065ef 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/reduce/TensorMmul.java @@ -16,6 +16,8 @@ package org.nd4j.linalg.api.ops.impl.reduce; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.shade.guava.primitives.Ints; import org.nd4j.shade.guava.primitives.Longs; import lombok.NoArgsConstructor; @@ -325,7 +327,8 @@ public class TensorMmul extends DynamicCustomOp { } @Override - public String tensorflowName() { - return "MatMul"; + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input data types for %s, got %s", getClass(), inputDataTypes); + return Collections.singletonList(inputDataTypes.get(0)); } } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java index b0090d047..12c852949 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/RSubOp.java @@ -61,16 +61,10 @@ public class RSubOp extends BaseDynamicTransformOp { throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName()); } - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No TensorFlow op name found for: " + getClass().getName()); - } - public RSubOp( INDArray[] inputs, INDArray[] outputs) { super(inputs, outputs); } - @Override public List doDiff(List i_v) { return f().rsubBp(larg(), rarg(), i_v.get(0)); diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/BaseDataBuffer.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/BaseDataBuffer.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/BaseDataBuffer.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/DataBuffer.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataBuffer.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/DataBuffer.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataType.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/DataType.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataType.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/DataType.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataTypeEx.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/DataTypeEx.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/DataTypeEx.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/DataTypeEx.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/allocation/MemoryStrategy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/allocation/MemoryStrategy.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/allocation/MemoryStrategy.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/allocation/MemoryStrategy.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/factory/DataBufferFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/factory/DataBufferFactory.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/factory/DataBufferFactory.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/factory/DataBufferFactory.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/util/AllocUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/util/AllocUtil.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/util/AllocUtil.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/util/AllocUtil.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/util/DataTypeUtil.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/util/DataTypeUtil.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/buffer/util/DataTypeUtil.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/buffer/util/DataTypeUtil.java diff --git a/nd4j/nd4j-context/src/main/java/org/nd4j/linalg/factory/Environment.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Environment.java similarity index 100% rename from nd4j/nd4j-context/src/main/java/org/nd4j/linalg/factory/Environment.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Environment.java diff --git a/nd4j/nd4j-context/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java similarity index 100% rename from nd4j/nd4j-context/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4jBackend.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/AllocationsTracker.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/AllocationsTracker.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/AllocationsTracker.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/AllocationsTracker.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/Deallocatable.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/Deallocatable.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/Deallocatable.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/Deallocatable.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/Deallocator.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/Deallocator.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/Deallocator.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/Deallocator.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/DeviceAllocationsTracker.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/DeviceAllocationsTracker.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/DeviceAllocationsTracker.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/DeviceAllocationsTracker.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/MemoryWorkspace.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/MemoryWorkspace.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/MemoryWorkspace.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/MemoryWorkspace.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/MemoryWorkspaceManager.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/MemoryWorkspaceManager.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/MemoryWorkspaceManager.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/MemoryWorkspaceManager.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/conf/WorkspaceConfiguration.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/conf/WorkspaceConfiguration.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/conf/WorkspaceConfiguration.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/conf/WorkspaceConfiguration.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/AllocationKind.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/AllocationKind.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/AllocationKind.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/AllocationKind.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/AllocationPolicy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/AllocationPolicy.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/AllocationPolicy.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/AllocationPolicy.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/DebugMode.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/DebugMode.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/DebugMode.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/DebugMode.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/LearningPolicy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/LearningPolicy.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/LearningPolicy.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/LearningPolicy.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/LocationPolicy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/LocationPolicy.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/LocationPolicy.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/LocationPolicy.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/MemoryKind.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/MemoryKind.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/MemoryKind.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/MemoryKind.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/MirroringPolicy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/MirroringPolicy.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/MirroringPolicy.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/MirroringPolicy.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/ResetPolicy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/ResetPolicy.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/ResetPolicy.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/ResetPolicy.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/SpillPolicy.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/SpillPolicy.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/enums/SpillPolicy.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/enums/SpillPolicy.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/pointers/ImmortalFloatPointer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/pointers/ImmortalFloatPointer.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/pointers/ImmortalFloatPointer.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/pointers/ImmortalFloatPointer.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/pointers/PagedPointer.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/pointers/PagedPointer.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/pointers/PagedPointer.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/pointers/PagedPointer.java diff --git a/nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/pointers/PointersPair.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/pointers/PointersPair.java similarity index 100% rename from nd4j/nd4j-buffer/src/main/java/org/nd4j/linalg/api/memory/pointers/PointersPair.java rename to nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/memory/pointers/PointersPair.java diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/pom.xml b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/pom.xml index 24621a21c..90748460a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/pom.xml +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-native-api/pom.xml @@ -30,11 +30,11 @@ - + org.nd4j nd4j-api diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/LongBuffer.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/LongBuffer.java index 80a7f9560..7ab2e8c61 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/LongBuffer.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/buffer/LongBuffer.java @@ -124,6 +124,7 @@ public class LongBuffer extends BaseCpuDataBuffer { indexer = LongIndexer.create((LongPointer) this.pointer); // we still want this buffer to have native representation + ptrDataBuffer = NativeOpsHolder.getInstance().getDeviceNativeOps().allocateDataBuffer(0, DataType.INT64.toInt(), false); NativeOpsHolder.getInstance().getDeviceNativeOps().dbSetPrimaryBuffer(ptrDataBuffer, this.pointer, numberOfElements); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java index 0fbfa2671..16f097084 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java @@ -30,8 +30,6 @@ import org.nd4j.linalg.factory.Nd4jBackend; import org.nd4j.linalg.ops.transforms.Transforms; import org.nd4j.nativeblas.NativeOpsHolder; -import java.util.Arrays; - import static org.junit.Assert.assertNull; @Slf4j @@ -179,8 +177,8 @@ public class ReductionBpOpValidation extends BaseOpValidation { @Test public void testMeanBP_Rank1() { INDArray dLdOut = Nd4j.scalar(0.5); - INDArray preReduceInput = Nd4j.create(new double[]{2,3,4}, new long[]{3}); - INDArray dLdInExp = Nd4j.valueArrayOf(new long[]{3}, 0.5/3); + INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3}); + INDArray dLdInExp = Nd4j.valueArrayOf(new long[]{3}, 0.5 / 3); INDArray dLdIn = Nd4j.createUninitialized(new long[]{3}); @@ -199,7 +197,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { for (boolean keepDims : new boolean[]{false, true}) { long[] reducedShape_0 = (keepDims ? new long[]{1, 4} : new long[]{4}); - INDArray preReduceInput = Nd4j.linspace(1, 12, 12).reshape('c',3, 4); + INDArray preReduceInput = Nd4j.linspace(1, 12, 12).reshape('c', 3, 4); INDArray dLdOut_0 = Nd4j.create(new double[]{1, 2, 3, 4}, reducedShape_0); INDArray dLdInExpected_0 = Nd4j.createUninitialized(preReduceInput.shape()); for (int i = 0; i < 3; i++) { @@ -524,7 +522,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { @Test public void testStdevBP_Rank1() { INDArray dLdOut = Nd4j.scalar(0.5); - INDArray preReduceInput = Nd4j.create(new double[]{2,3,4}, new long[]{3}); + INDArray preReduceInput = Nd4j.create(new double[]{2, 3, 4}, new long[]{3}); double stdev = preReduceInput.stdNumber(true).doubleValue(); double mean = preReduceInput.meanNumber().doubleValue(); @@ -577,7 +575,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { INDArray dLdInExpected_1 = preReduceInput.dup(); dLdInExpected_1.subiColumnVector(mean_1) .diviColumnVector(stdev_1.mul(divisor)) - .muliColumnVector(dLdOut_1.reshape(3,1)); + .muliColumnVector(dLdOut_1.reshape(3, 1)); dLdIn = Nd4j.createUninitialized(3, 4); err = OpValidation.validate(new OpTestCase(new StandardDeviationBp(preReduceInput, dLdOut_1, dLdIn, biasCorrected, keepDims, 1)) @@ -653,7 +651,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { INDArray mean_1 = preReduceInput.mean(1); INDArray dLdInExpected_1 = preReduceInput.dup(); dLdInExpected_1.subiColumnVector(mean_1).muli(2.0 / divisor) - .muliColumnVector(dLdOut_1.reshape(3,1)); + .muliColumnVector(dLdOut_1.reshape(3, 1)); dLdIn = Nd4j.createUninitialized(3, 4); @@ -688,17 +686,16 @@ public class ReductionBpOpValidation extends BaseOpValidation { // = cumSumExclusive(dL/dOut_j) - - for(boolean exclusive : new boolean[]{false, true}) { - for(boolean reverse : new boolean[]{false, true}) { + for (boolean exclusive : new boolean[]{false, true}) { + for (boolean reverse : new boolean[]{false, true}) { INDArray preReduceInput = Nd4j.linspace(1, 12, 12).reshape(3, 4); - INDArray dLdOut = Nd4j.valueArrayOf(new long[]{3,4}, 0.5); + INDArray dLdOut = Nd4j.valueArrayOf(new long[]{3, 4}, 0.5); INDArray dLdIn = Nd4j.createUninitialized(3, 4); INDArray dLdInExpected; - if(exclusive){ - if(reverse){ + if (exclusive) { + if (reverse) { dLdInExpected = Nd4j.create(new double[][]{ {0.0, 0.0, 0.0, 0.0}, {0.5, 0.5, 0.5, 0.5}, @@ -710,7 +707,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { {0.0, 0.0, 0.0, 0.0}}); } } else { - if(reverse){ + if (reverse) { dLdInExpected = Nd4j.create(new double[][]{ {0.5, 0.5, 0.5, 0.5}, {1.0, 1.0, 1.0, 1.0}, @@ -727,7 +724,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { String err = OpValidation.validate(new OpTestCase( new CumSumBp(preReduceInput, dLdOut, dLdIn, exclusive, reverse, 0)) .expectedOutput(0, dLdInExpected)); - if(err != null){ + if (err != null) { err = err + " - exclusive=" + exclusive + ", reverse=" + reverse; } assertNull(err); @@ -737,7 +734,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { @Test - public void testNorm2Bp(){ + public void testNorm2Bp() { //dL/dIn = dL/dOut * dOut/dIn // = dL/dOut * x/|x|_2 @@ -797,7 +794,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testNorm1Bp(){ + public void testNorm1Bp() { //dL/dIn = dL/dOut * dOut/dIn // = dL/dOut * sgn(in) @@ -856,7 +853,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { } @Test - public void testNormMaxBp(){ + public void testNormMaxBp() { //out = max_i (|in_i|) //dL/dIn = dL/dOut * dOut/dIn // = dL/dOut * (0 if |x_i| is not max; or sgn(x_i) otherwise) @@ -866,8 +863,8 @@ public class ReductionBpOpValidation extends BaseOpValidation { INDArray preReduceInput = Nd4j.linspace(-5, 6, 12).reshape(3, 4); INDArray sgn = Transforms.sign(preReduceInput, true); - INDArray max = Nd4j.create(3,4); - max.putScalar(2,3,1.0); + INDArray max = Nd4j.create(3, 4); + max.putScalar(2, 3, 1.0); INDArray dLdOut; if (keepDims) { @@ -896,7 +893,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { long[] reducedShape_0 = (keepDims ? new long[]{1, 4} : new long[]{4}); INDArray preReduceInput = Nd4j.linspace(1, 12, 12).reshape(3, 4); INDArray sgn = Transforms.sign(preReduceInput, true); - INDArray max_0 = Nd4j.create(3,4); + INDArray max_0 = Nd4j.create(3, 4); max_0.getRow(2).assign(1.0); INDArray dLdOut_0 = Nd4j.create(new double[]{1, 2, 3, 4}, reducedShape_0); INDArray dLdInExpected_0 = sgn.mul(max_0).mulRowVector(dLdOut_0); @@ -910,7 +907,7 @@ public class ReductionBpOpValidation extends BaseOpValidation { long[] reducedShape_1 = (keepDims ? new long[]{3, 1} : new long[]{3}); INDArray dLdOut_1 = Nd4j.create(new double[]{1, 2, 3}, reducedShape_1); - INDArray max_1 = Nd4j.create(3,4); + INDArray max_1 = Nd4j.create(3, 4); max_1.getColumn(3).assign(1.0); INDArray dLdInExpected_1 = sgn.mul(max_1).mulColumnVector(dLdOut_1); dLdIn = Nd4j.createUninitialized(3, 4); @@ -921,60 +918,5 @@ public class ReductionBpOpValidation extends BaseOpValidation { assertNull(err, err); } } - - @Test - public void testPowBP() { - - for (boolean keepDims : new boolean[]{false, true}) { - - INDArray preReduceInput_1 = Nd4j.createFromArray(new double[]{ - 4,3,2,5,7,8,-9,-12 - }).reshape(2,2,2); - INDArray preReduceInput_2 = Nd4j.createFromArray(new double[]{ - 2,3,-2,4,-1,-4,10,8 - }).reshape(2,2,2); - INDArray preReduceInput_3 = Nd4j.linspace(1, 8, 8).reshape(2, 2,2); - INDArray gradOutput = Nd4j.valueArrayOf(new long[]{2, 2, 2}, 1.0); - INDArray dLdInExpected_1 = Nd4j.createFromArray(new double[]{ - 8, 27, -0.25, 500, -0.0204082, -0.000122, -3.87420e+09, -2.86654e+08 - }).reshape(2,2,2); - INDArray dLdInExpected_2 = Nd4j.createFromArray(new double[]{ - 22.18071, 29.66253, 0.17329, 1005.89874, 0.27799, 0.00051, 0, 0 - }).reshape(2,2,2); - INDArray output1 = Nd4j.createUninitialized(2, 2,2); - INDArray output2 = Nd4j.createUninitialized(2, 2,2); - - String err = OpValidation.validate(new OpTestCase(new PowBp(preReduceInput_1, preReduceInput_2, - gradOutput, output1, output2)) - .expectedOutput(0, dLdInExpected_1).expectedOutput(1, dLdInExpected_2)); - - assertNull(err); - } - } - - @Test - public void testPowBP1() { - - INDArray preReduceInput_1 = Nd4j.createFromArray(new float[]{ - 0.0714f, 0.4735f, -0.1249f, 0.4482f, - -0.1376f, 0.5218f, 0.5558f, 0.2444f, - -0.5297f, 0.4291f, 0.4913f, -0.1178f - }).reshape(3,4); - INDArray preReduceInput_2 = Nd4j.scalar(2.0000f); - - INDArray gradOutput = Nd4j.valueArrayOf(new long[]{3, 4}, 1.0f); - - INDArray output1 = Nd4j.createUninitialized(DataType.FLOAT, 3,4); - INDArray output2 = Nd4j.scalar(DataType.FLOAT, 1.0); //Nd4j.createUninitialized(DataType.FLOAT, 3,4); - - INDArray expected1 = Nd4j.createFromArray(new float[]{ - 0.1428f, 0.9470f, -0.2498f, 0.8964f, - -0.2752f, 1.0436f, 1.1116f, 0.4888f, - -1.0594f, 0.8582f, 0.9826f, -0.2356f - }).reshape(3,4); - INDArray expected2 = Nd4j.scalar(DataType.FLOAT, -1.112316132); - String err = OpValidation.validate(new OpTestCase(new PowBp(preReduceInput_1, preReduceInput_2, - gradOutput, output1, output2)).expectedOutput(0, expected1).expectedOutput(1, expected2)); - assertNull(err); - } } + diff --git a/nd4j/nd4j-buffer/pom.xml b/nd4j/nd4j-buffer/pom.xml deleted file mode 100644 index 72869b16c..000000000 --- a/nd4j/nd4j-buffer/pom.xml +++ /dev/null @@ -1,154 +0,0 @@ - - - - - nd4j - org.nd4j - 1.0.0-SNAPSHOT - - 4.0.0 - - nd4j-buffer - jar - - nd4j-buffer - - - - org.apache.maven.plugins - maven-compiler-plugin - - 7 - 7 - - - - - - - - - - linux - - linux - - - linux - - - - macosx - - mac os x - - - macosx - - - - windows - - windows - - - windows - - - - i386 - - i386 - - - x86_64 - - - - i486 - - i486 - - - x86_64 - - - - i586 - - i586 - - - x86_64 - - - - i686 - - i686 - - - x86_64 - - - - x86 - - x86 - - - x86_64 - - - - amd64 - - amd64 - - - x86_64 - - - - x86-64 - - x86-64 - - - x86_64 - - - - testresources - - - - - - - - org.nd4j - nd4j-context - ${project.version} - - - org.bytedeco - javacpp - ${javacpp.version} - - - diff --git a/nd4j/nd4j-context/pom.xml b/nd4j/nd4j-context/pom.xml deleted file mode 100644 index 225b3784c..000000000 --- a/nd4j/nd4j-context/pom.xml +++ /dev/null @@ -1,48 +0,0 @@ - - - - - nd4j - org.nd4j - 1.0.0-SNAPSHOT - - 4.0.0 - - nd4j-context - jar - - nd4j-context - - - - org.nd4j - nd4j-common - ${project.version} - - - - - 1.7 - 1.7 - - - - - testresources - - - diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/KafkaConnectionInformation.java b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/KafkaConnectionInformation.java deleted file mode 100644 index 9f30cb549..000000000 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/KafkaConnectionInformation.java +++ /dev/null @@ -1,51 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.camel.kafka; - -import kafka.serializer.StringEncoder; -import lombok.Builder; -import lombok.Data; - -import java.io.Serializable; - -/** - * Kafka connection information - * to generate camel uris - * - * @author Adam Gibson - */ -@Builder -@Data -public class KafkaConnectionInformation implements Serializable { - private String zookeeperHost; - private int zookeeperPort; - private String kafkaBrokerList; - private String topicName; - private String groupId; - - /** - * Returns a kafka connection uri - * @return a kafka connection uri - * represented by this connection information - */ - public String kafkaUri() { - return String.format( - "kafka://%s?topic=%s&groupId=%s&zookeeperHost=%s&zookeeperPort=%d&serializerClass=%s&keySerializerClass=%s", - kafkaBrokerList, topicName, groupId, zookeeperHost, zookeeperPort, - StringEncoder.class.getName(), StringEncoder.class.getName()); - } -} diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/Nd4jKafkaConsumer.java b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/Nd4jKafkaConsumer.java deleted file mode 100644 index 7e73e4716..000000000 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/Nd4jKafkaConsumer.java +++ /dev/null @@ -1,45 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.camel.kafka; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import org.apache.camel.CamelContext; -import org.apache.camel.ConsumerTemplate; -import org.nd4j.linalg.api.ndarray.INDArray; - -/** - * Created by agibsonccc on 7/19/16. - */ -@AllArgsConstructor -@Builder -public class Nd4jKafkaConsumer { - private KafkaConnectionInformation connectionInformation; - private ConsumerTemplate consumerTemplate; - private CamelContext camelContext; - - /** - * Receive an ndarray - * @return - */ - public INDArray receive() { - if (consumerTemplate == null) - consumerTemplate = camelContext.createConsumerTemplate(); - return consumerTemplate.receiveBody("direct:receive", INDArray.class); - } - -} diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/Nd4jKafkaProducer.java b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/Nd4jKafkaProducer.java deleted file mode 100644 index 4120abacb..000000000 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/Nd4jKafkaProducer.java +++ /dev/null @@ -1,48 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.camel.kafka; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import org.apache.camel.CamelContext; -import org.apache.camel.ProducerTemplate; -import org.nd4j.linalg.api.ndarray.INDArray; - -/** - * Created by agibsonccc on 7/19/16. - */ -@AllArgsConstructor -@Builder -public class Nd4jKafkaProducer { - - private KafkaConnectionInformation connectionInformation; - private CamelContext camelContext; - private ProducerTemplate producerTemplate; - - /** - * Publish to a kafka topic - * based on the connection information - * @param arr - */ - public void publish(INDArray arr) { - if (producerTemplate == null) - producerTemplate = camelContext.createProducerTemplate(); - producerTemplate.sendBody("direct:start", arr); - } - - -} diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/Nd4jKafkaRoute.java b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/Nd4jKafkaRoute.java deleted file mode 100644 index 909db800d..000000000 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/java/org/nd4j/camel/kafka/Nd4jKafkaRoute.java +++ /dev/null @@ -1,74 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.camel.kafka; - -import lombok.AllArgsConstructor; -import lombok.Builder; -import org.apache.camel.Exchange; -import org.apache.camel.Processor; -import org.apache.camel.builder.RouteBuilder; -import org.apache.camel.component.kafka.KafkaConstants; -import org.apache.commons.net.util.Base64; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.serde.base64.Nd4jBase64; - -import java.io.ByteArrayOutputStream; -import java.io.DataOutputStream; -import java.util.UUID; - -/** - * Sends a test ndarray - * to kafka - * - * @author Adam Gibson - */ -@AllArgsConstructor -@Builder -public class Nd4jKafkaRoute extends RouteBuilder { - private KafkaConnectionInformation kafkaConnectionInformation; - - @Override - public void configure() throws Exception { - final String kafkaUri = kafkaConnectionInformation.kafkaUri(); - from("direct:start").process(new Processor() { - @Override - public void process(Exchange exchange) throws Exception { - final INDArray arr = (INDArray) exchange.getIn().getBody(); - ByteArrayOutputStream bos = new ByteArrayOutputStream(); - DataOutputStream dos = new DataOutputStream(bos); - Nd4j.write(arr, dos); - byte[] bytes = bos.toByteArray(); - String base64 = Base64.encodeBase64String(bytes); - exchange.getIn().setBody(base64, String.class); - String id = UUID.randomUUID().toString(); - exchange.getIn().setHeader(KafkaConstants.KEY, id); - exchange.getIn().setHeader(KafkaConstants.PARTITION_KEY, id); - } - }).to(kafkaUri); - - from(kafkaUri).process(new Processor() { - @Override - public void process(Exchange exchange) throws Exception { - byte[] body2 = (byte[]) exchange.getIn().getBody(); - String body = new String(body2); - INDArray arr = Nd4jBase64.fromBase64(body); - exchange.getIn().setBody(arr); - } - }).to("direct:receive"); - } -} diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/EmbeddedKafkaCluster.java b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/EmbeddedKafkaCluster.java deleted file mode 100644 index 3e01b1e6c..000000000 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/EmbeddedKafkaCluster.java +++ /dev/null @@ -1,177 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.kafka; - -import kafka.admin.AdminUtils; -import kafka.server.KafkaConfig; -import kafka.server.KafkaServer; -import org.I0Itec.zkclient.ZkClient; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.File; -import java.io.FileNotFoundException; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Properties; - -public class EmbeddedKafkaCluster { - private static final Logger LOG = LoggerFactory.getLogger(EmbeddedKafkaCluster.class); - - private final List ports; - private final String zkConnection; - private final Properties baseProperties; - - private final String brokerList; - - private final List brokers; - private final List logDirs; - - public EmbeddedKafkaCluster(String zkConnection) { - this(zkConnection, new Properties()); - } - - public EmbeddedKafkaCluster(String zkConnection, Properties baseProperties) { - this(zkConnection, baseProperties, Collections.singletonList(-1)); - } - - public EmbeddedKafkaCluster(String zkConnection, Properties baseProperties, List ports) { - this.zkConnection = zkConnection; - this.ports = resolvePorts(ports); - this.baseProperties = baseProperties; - this.brokers = new ArrayList(); - this.logDirs = new ArrayList(); - - this.brokerList = constructBrokerList(this.ports); - } - - public ZkClient getZkClient() { - for (KafkaServer server : brokers) { - return server.zkClient(); - } - return null; - } - - public void createTopics(String... topics) { - for (String topic : topics) { - AdminUtils.createTopic(getZkClient(), topic, 2, 1, new Properties()); - } - } - - private List resolvePorts(List ports) { - List resolvedPorts = new ArrayList(); - for (Integer port : ports) { - resolvedPorts.add(resolvePort(port)); - } - return resolvedPorts; - } - - private int resolvePort(int port) { - if (port == -1) { - return TestUtils.getAvailablePort(); - } - return port; - } - - private String constructBrokerList(List ports) { - StringBuilder sb = new StringBuilder(); - for (Integer port : ports) { - if (sb.length() > 0) { - sb.append(","); - } - sb.append("localhost:").append(port); - } - return sb.toString(); - } - - public void startup() { - for (int i = 0; i < ports.size(); i++) { - Integer port = ports.get(i); - File logDir = TestUtils.constructTempDir("kafka-local"); - - Properties properties = new Properties(); - properties.putAll(baseProperties); - properties.setProperty("zookeeper.connect", zkConnection); - properties.setProperty("broker.id", String.valueOf(i + 1)); - properties.setProperty("host.opName", "localhost"); - properties.setProperty("port", Integer.toString(port)); - properties.setProperty("log.dir", logDir.getAbsolutePath()); - properties.setProperty("num.partitions", String.valueOf(1)); - properties.setProperty("auto.create.topics.enable", String.valueOf(Boolean.TRUE)); - properties.setProperty("log.flush.interval.messages", String.valueOf(1)); - LOG.info("EmbeddedKafkaCluster: local directory: " + logDir.getAbsolutePath()); - - KafkaServer broker = startBroker(properties); - - brokers.add(broker); - logDirs.add(logDir); - } - } - - - private KafkaServer startBroker(Properties props) { - KafkaServer server = new KafkaServer(new KafkaConfig(props), new SystemTime()); - server.startup(); - return server; - } - - public Properties getProps() { - Properties props = new Properties(); - props.putAll(baseProperties); - props.put("metadata.broker.list", brokerList); - props.put("zookeeper.connect", zkConnection); - return props; - } - - public String getBrokerList() { - return brokerList; - } - - public List getPorts() { - return ports; - } - - public String getZkConnection() { - return zkConnection; - } - - public void shutdown() { - for (KafkaServer broker : brokers) { - try { - broker.shutdown(); - } catch (Exception e) { - e.printStackTrace(); - } - } - for (File logDir : logDirs) { - try { - TestUtils.deleteFile(logDir); - } catch (FileNotFoundException e) { - e.printStackTrace(); - } - } - } - - @Override - public String toString() { - final StringBuilder sb = new StringBuilder("EmbeddedKafkaCluster{"); - sb.append("brokerList='").append(brokerList).append('\''); - sb.append('}'); - return sb.toString(); - } -} diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/EmbeddedZookeeper.java b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/EmbeddedZookeeper.java deleted file mode 100644 index a48b37e5f..000000000 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/EmbeddedZookeeper.java +++ /dev/null @@ -1,112 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2019 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.kafka; - -import org.apache.zookeeper.server.ServerCnxnFactory; -import org.apache.zookeeper.server.ZooKeeperServer; - -import java.io.File; -import java.io.FileNotFoundException; -import java.io.IOException; -import java.net.InetSocketAddress; - -public class EmbeddedZookeeper { - private int port = -1; - private int tickTime = 500; - - private ServerCnxnFactory factory; - private File snapshotDir; - private File logDir; - - public EmbeddedZookeeper() { - this(-1); - } - - public EmbeddedZookeeper(int port) { - this(port, 500); - } - - public EmbeddedZookeeper(int port, int tickTime) { - this.port = resolvePort(port); - this.tickTime = tickTime; - } - - private int resolvePort(int port) { - if (port == -1) { - return TestUtils.getAvailablePort(); - } - return port; - } - - public void startup() throws IOException { - if (this.port == -1) { - this.port = TestUtils.getAvailablePort(); - } - this.factory = ServerCnxnFactory.createFactory(new InetSocketAddress("localhost", port), 1024); - this.snapshotDir = TestUtils.constructTempDir("embeeded-zk/snapshot"); - this.logDir = TestUtils.constructTempDir("embeeded-zk/log"); - - try { - factory.startup(new ZooKeeperServer(snapshotDir, logDir, tickTime)); - } catch (InterruptedException e) { - throw new IOException(e); - } - } - - - public void shutdown() { - factory.shutdown(); - try { - TestUtils.deleteFile(snapshotDir); - } catch (FileNotFoundException e) { - // ignore - } - try { - TestUtils.deleteFile(logDir); - } catch (FileNotFoundException e) { - // ignore - } - } - - public String getConnection() { - return "localhost:" + port; - } - - public void setPort(int port) { - this.port = port; - } - - public void setTickTime(int tickTime) { - this.tickTime = tickTime; - } - - public int getPort() { - return port; - } - - public int getTickTime() { - return tickTime; - } - - @Override - public String toString() { - final StringBuilder sb = new StringBuilder("EmbeddedZookeeper{"); - sb.append("connection=").append(getConnection()); - sb.append('}'); - return sb.toString(); - } -} diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/SystemTime.java b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/SystemTime.java deleted file mode 100644 index cb2136309..000000000 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/SystemTime.java +++ /dev/null @@ -1,38 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.kafka; - - -import kafka.utils.Time; - -class SystemTime implements Time { - public long milliseconds() { - return System.currentTimeMillis(); - } - - public long nanoseconds() { - return System.nanoTime(); - } - - public void sleep(long ms) { - try { - Thread.sleep(ms); - } catch (InterruptedException e) { - // Ignore - } - } -} diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/TestUtils.java b/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/TestUtils.java deleted file mode 100644 index a93f18adc..000000000 --- a/nd4j/nd4j-serde/nd4j-camel-routes/nd4j-kafka/src/main/test/java/org/nd4j/kafka/TestUtils.java +++ /dev/null @@ -1,64 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.kafka; - -import java.io.File; -import java.io.FileNotFoundException; -import java.io.IOException; -import java.net.ServerSocket; -import java.util.Random; - -public class TestUtils { - private static final Random RANDOM = new Random(); - - private TestUtils() {} - - public static File constructTempDir(String dirPrefix) { - File file = new File(System.getProperty("java.io.tmpdir"), dirPrefix + RANDOM.nextInt(10000000)); - if (!file.mkdirs()) { - throw new RuntimeException("could not create temp directory: " + file.getAbsolutePath()); - } - file.deleteOnExit(); - return file; - } - - public static int getAvailablePort() { - try { - ServerSocket socket = new ServerSocket(0); - try { - return socket.getLocalPort(); - } finally { - socket.close(); - } - } catch (IOException e) { - throw new IllegalStateException("Cannot find available port: " + e.getMessage(), e); - } - } - - public static boolean deleteFile(File path) throws FileNotFoundException { - if (!path.exists()) { - throw new FileNotFoundException(path.getAbsolutePath()); - } - boolean ret = true; - if (path.isDirectory()) { - for (File f : path.listFiles()) { - ret = ret && deleteFile(f); - } - } - return ret && path.delete(); - } -} diff --git a/nd4j/nd4j-serde/nd4j-camel-routes/pom.xml b/nd4j/nd4j-serde/nd4j-camel-routes/pom.xml deleted file mode 100644 index 94be439c3..000000000 --- a/nd4j/nd4j-serde/nd4j-camel-routes/pom.xml +++ /dev/null @@ -1,40 +0,0 @@ - - - - - nd4j-serde - org.nd4j - 1.0.0-SNAPSHOT - - 4.0.0 - - nd4j-camel-routes - pom - - nd4j-camel-routes - https://deeplearning4j.org - - nd4j-kafka - - - - - testresources - - - - diff --git a/nd4j/nd4j-serde/nd4j-gson/src/main/java/org/nd4j/serde/gson/GsonDeserializationUtils.java b/nd4j/nd4j-serde/nd4j-gson/src/main/java/org/nd4j/serde/gson/GsonDeserializationUtils.java deleted file mode 100644 index 5b087a055..000000000 --- a/nd4j/nd4j-serde/nd4j-gson/src/main/java/org/nd4j/serde/gson/GsonDeserializationUtils.java +++ /dev/null @@ -1,103 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.serde.gson; - -import org.nd4j.shade.guava.primitives.Ints; -import org.nd4j.shade.guava.primitives.Longs; -import com.google.gson.JsonArray; -import com.google.gson.JsonElement; -import com.google.gson.JsonParser; -import org.apache.commons.lang3.StringUtils; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; - -import java.text.NumberFormat; -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; - -/** - * Gson serialization - * - * @author Alex Black - * @author Adam Gibson - */ -public class GsonDeserializationUtils { - private static final JsonParser JSON_PARSER = new JsonParser(); - - static { - NumberFormat format = NumberFormat.getIntegerInstance(); - format.setGroupingUsed(false); - } - - /** - * Deserialize an ndarray - * form json - * @param serializedRawArray - * @return - */ - public static INDArray deserializeRawJson(String serializedRawArray) { - - //String cleanedRawArray = serializedRawArray.replaceAll("(?<=[\\d])(,)(?=[\\d])", ""); - String cleanedRawArray = serializedRawArray; - JsonArray jsonArray = JSON_PARSER.parse(cleanedRawArray).getAsJsonArray(); - - List dimensions = new ArrayList<>(); - dimensions.add(jsonArray.size()); - getSizeMultiDimensionalArray(jsonArray, dimensions); - - return buildArray(dimensions, cleanedRawArray); - } - - /* - The below method works under the following assumption - which is an INDArray can not have a row such as [ 1 , 2, [3, 4] ] - and either all elements of an INDArray are either INDArrays themselves or scalars. - So if that is the case, then it suffices to only check the first element of each JsonArray - to see if that first element is itself an JsonArray. If it is an array, then we must check - the first element of that array to see if it's a scalar or array. - */ - - private static void getSizeMultiDimensionalArray(JsonArray jsonArray, List dimensions) { - Iterator iterator = jsonArray.iterator(); - - if (iterator.hasNext()) { - JsonElement jsonElement = iterator.next(); - if (jsonElement.isJsonArray()) { - JsonArray shapeArray = jsonElement.getAsJsonArray(); - dimensions.add(shapeArray.size()); - getSizeMultiDimensionalArray(shapeArray, dimensions); - } - } - } - - private static boolean isArrayWithSingleRow(List dimensions) { - return dimensions.size() == 1; - } - - private static INDArray buildArray(List dimensions, String rawArray) { - long[] shape = Longs.toArray(dimensions); - String[] entries = StringUtils.replacePattern(rawArray, "[\\[\\]\\n]", "").split(","); - double[] entryValues = new double[entries.length]; - - for (int i = 0; i < entries.length; i++) { - entryValues[i] = Double.parseDouble(entries[i]); - } - - return Nd4j.create(entryValues, shape, Nd4j.defaultFloatingPointType()); - } -} diff --git a/nd4j/nd4j-serde/nd4j-jackson/pom.xml b/nd4j/nd4j-serde/nd4j-jackson/pom.xml deleted file mode 100644 index f1ef0aa79..000000000 --- a/nd4j/nd4j-serde/nd4j-jackson/pom.xml +++ /dev/null @@ -1,56 +0,0 @@ - - - - - - nd4j-serde - org.nd4j - 1.0.0-SNAPSHOT - - 4.0.0 - - nd4j-jackson - - - com.fasterxml.jackson.core - jackson-databind - ${jackson.version} - - provided - - - org.nd4j - nd4j-api - ${project.version} - provided - - - - - org.nd4j - jackson - ${project.version} - jar - - - - - - testresources - - - diff --git a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/VectorDeSerializer.java b/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/VectorDeSerializer.java deleted file mode 100644 index db4125d16..000000000 --- a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/VectorDeSerializer.java +++ /dev/null @@ -1,59 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.shade.serde.jackson; - -import org.nd4j.linalg.api.buffer.DataBuffer; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.shade.jackson.core.JsonParser; -import org.nd4j.shade.jackson.databind.DeserializationContext; -import org.nd4j.shade.jackson.databind.JsonDeserializer; -import org.nd4j.shade.jackson.databind.JsonNode; - -import java.io.IOException; - -/** - * @author Adam Gibson - * @deprecated Use {@link org.nd4j.serde.jackson.VectorDeSerializer} - */ -public class VectorDeSerializer extends JsonDeserializer { - @Override - public INDArray deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException { - JsonNode node = jp.getCodec().readTree(jp); - JsonNode arr = node.get("dataBuffer"); - int rank = node.get("rankField").asInt(); - int numElements = node.get("numElements").asInt(); - int offset = node.get("offsetField").asInt(); - JsonNode shape = node.get("shapeField"); - JsonNode stride = node.get("strideField"); - int[] realShape = new int[rank]; - int[] realStride = new int[rank]; - DataBuffer buff = Nd4j.createBuffer(numElements); - for (int i = 0; i < numElements; i++) { - buff.put(i, arr.get(i).asDouble()); - } - - String ordering = node.get("orderingField").asText(); - for (int i = 0; i < rank; i++) { - realShape[i] = shape.get(i).asInt(); - realStride[i] = stride.get(i).asInt(); - } - - INDArray ret = Nd4j.create(buff, realShape, realStride, offset, ordering.charAt(0)); - return ret; - } -} diff --git a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/VectorSerializer.java b/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/VectorSerializer.java deleted file mode 100644 index 651e1980c..000000000 --- a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/VectorSerializer.java +++ /dev/null @@ -1,64 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.shade.serde.jackson; - - -import org.nd4j.linalg.api.buffer.DataBuffer; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.shade.jackson.core.JsonGenerator; -import org.nd4j.shade.jackson.databind.JsonSerializer; -import org.nd4j.shade.jackson.databind.SerializerProvider; - -import java.io.IOException; - -/** - * @author Adam Gibson - * @deprecated Use {@link org.nd4j.serde.jackson.VectorSerializer} - */ -public class VectorSerializer extends JsonSerializer { - @Override - public void serialize(INDArray indArray, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) - throws IOException { - if (indArray.isView()) - indArray = indArray.dup(indArray.ordering()); - jsonGenerator.writeStartObject(); - DataBuffer view = indArray.data(); - jsonGenerator.writeArrayFieldStart("dataBuffer"); - for (int i = 0; i < view.length(); i++) { - jsonGenerator.writeNumber(view.getDouble(i)); - } - - jsonGenerator.writeEndArray(); - - jsonGenerator.writeArrayFieldStart("shapeField"); - for (int i = 0; i < indArray.rank(); i++) { - jsonGenerator.writeNumber(indArray.size(i)); - } - jsonGenerator.writeEndArray(); - - jsonGenerator.writeArrayFieldStart("strideField"); - for (int i = 0; i < indArray.rank(); i++) - jsonGenerator.writeNumber(indArray.stride(i)); - jsonGenerator.writeEndArray(); - - jsonGenerator.writeNumberField("offsetField", indArray.offset()); - jsonGenerator.writeNumberField("rankField", indArray.rank()); - jsonGenerator.writeNumberField("numElements", view.length()); - jsonGenerator.writeStringField("orderingField", String.valueOf(indArray.ordering())); - jsonGenerator.writeEndObject(); - } -} diff --git a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/ndarray/NDArrayDeSerializer.java b/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/ndarray/NDArrayDeSerializer.java deleted file mode 100644 index 95560fbb6..000000000 --- a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/ndarray/NDArrayDeSerializer.java +++ /dev/null @@ -1,40 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.shade.serde.jackson.ndarray; - -import com.fasterxml.jackson.core.JsonParser; -import com.fasterxml.jackson.databind.DeserializationContext; -import com.fasterxml.jackson.databind.JsonDeserializer; -import com.fasterxml.jackson.databind.JsonNode; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.serde.base64.Nd4jBase64; - -import java.io.IOException; - -/** - * @author Adam Gibson - * @deprecated Use {@link org.nd4j.serde.jackson.ndarray.NDArrayDeSerializer} - */ -public class NDArrayDeSerializer extends JsonDeserializer { - @Override - public INDArray deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException { - JsonNode node = jp.getCodec().readTree(jp); - String field = node.get("array").asText(); - INDArray ret = Nd4jBase64.fromBase64(field); - return ret; - } -} diff --git a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/ndarray/NDArraySerializer.java b/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/ndarray/NDArraySerializer.java deleted file mode 100644 index abb7fc209..000000000 --- a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/ndarray/NDArraySerializer.java +++ /dev/null @@ -1,44 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.shade.serde.jackson.ndarray; - - -import com.fasterxml.jackson.core.JsonGenerator; -import com.fasterxml.jackson.databind.JsonSerializer; -import com.fasterxml.jackson.databind.SerializerProvider; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.serde.base64.Nd4jBase64; - -import java.io.ByteArrayOutputStream; -import java.io.IOException; - -/** - * @author Adam Gibson - * @deprecated Use {@link org.nd4j.serde.jackson.ndarray.NDArraySerializer} - */ -public class NDArraySerializer extends JsonSerializer { - @Override - public void serialize(INDArray indArray, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) - throws IOException { - String toBase64 = Nd4jBase64.base64String(indArray); - jsonGenerator.writeStartObject(); - jsonGenerator.writeStringField("array", toBase64); - jsonGenerator.writeEndObject(); - - } -} diff --git a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/shaded/NDArrayDeSerializer.java b/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/shaded/NDArrayDeSerializer.java deleted file mode 100644 index e47a241dd..000000000 --- a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/shaded/NDArrayDeSerializer.java +++ /dev/null @@ -1,42 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.shade.serde.jackson.shaded; - -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.serde.base64.Nd4jBase64; -import org.nd4j.shade.jackson.core.JsonParser; -import org.nd4j.shade.jackson.databind.DeserializationContext; -import org.nd4j.shade.jackson.databind.JsonDeserializer; -import org.nd4j.shade.jackson.databind.JsonNode; - -import java.io.IOException; - -/** - * @author Adam Gibson - * @deprecated Use {@link org.nd4j.serde.jackson.shaded.NDArrayDeSerializer} - */ -@Deprecated -public class NDArrayDeSerializer extends JsonDeserializer { - @Override - public INDArray deserialize(JsonParser jp, DeserializationContext deserializationContext) throws IOException { - JsonNode node = jp.getCodec().readTree(jp); - String field = node.get("array").asText(); - INDArray ret = Nd4jBase64.fromBase64(field.toString()); - return ret; - - } -} diff --git a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/shaded/NDArraySerializer.java b/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/shaded/NDArraySerializer.java deleted file mode 100644 index 6667e6685..000000000 --- a/nd4j/nd4j-serde/nd4j-jackson/src/main/java/org/nd4j/shade/serde/jackson/shaded/NDArraySerializer.java +++ /dev/null @@ -1,44 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.shade.serde.jackson.shaded; - - -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.factory.Nd4j; -import org.nd4j.serde.base64.Nd4jBase64; -import org.nd4j.shade.jackson.core.JsonGenerator; -import org.nd4j.shade.jackson.databind.JsonSerializer; -import org.nd4j.shade.jackson.databind.SerializerProvider; - -import java.io.ByteArrayOutputStream; -import java.io.IOException; - -/** - * @author Adam Gibson - * @deprecated Use {@link org.nd4j.serde.jackson.shaded.NDArraySerializer} - */ -@Deprecated -public class NDArraySerializer extends JsonSerializer { - @Override - public void serialize(INDArray indArray, JsonGenerator jsonGenerator, SerializerProvider serializerProvider) - throws IOException { - String toBase64 = Nd4jBase64.base64String(indArray); - jsonGenerator.writeStartObject(); - jsonGenerator.writeStringField("array", toBase64); - jsonGenerator.writeEndObject(); - } -} diff --git a/nd4j/nd4j-serde/pom.xml b/nd4j/nd4j-serde/pom.xml index d4fc4ff05..aa6f9bed1 100644 --- a/nd4j/nd4j-serde/pom.xml +++ b/nd4j/nd4j-serde/pom.xml @@ -28,10 +28,7 @@ pom nd4j-aeron - nd4j-jackson nd4j-kryo - nd4j-camel-routes - nd4j-gson nd4j-arrow diff --git a/nd4j/nd4j-uberjar/pom.xml b/nd4j/nd4j-uberjar/pom.xml index c3398dea9..84f1c0d4a 100644 --- a/nd4j/nd4j-uberjar/pom.xml +++ b/nd4j/nd4j-uberjar/pom.xml @@ -205,16 +205,6 @@ nd4j-common ${project.version} - - org.nd4j - nd4j-buffer - ${project.version} - - - org.nd4j - nd4j-context - ${project.version} - org.nd4j nd4j-api diff --git a/nd4j/pom.xml b/nd4j/pom.xml index ad87566df..d8e4a58f2 100644 --- a/nd4j/pom.xml +++ b/nd4j/pom.xml @@ -56,8 +56,6 @@ nd4j-jdbc nd4j-serde nd4j-common - nd4j-buffer - nd4j-context nd4j-backends nd4j-parameter-server-parent nd4j-uberjar