Cleanup of multiple projects (#175)

* Cleanup modules

* Moving subprojects to nd4j-api

* Project cleanup

* Dropped AWS sub-project

* dl4j-util moved to core

* dl4j-perf moved to core

* Tests coverage

* Revert "Moving subprojects to nd4j-api"

This reverts commit bc6eb573c6b60c407ade47172c5d204725077e6b.

* Moved nd4j-buffer and nd4j-context to nd4j-api

* Rolled back change

* Revert "Project cleanup"

This reverts commit 64ac7f369b2d968f7be437718034f093fc886ffc.

* Datavec cleaned up

* Revert "Moved nd4j-buffer and nd4j-context to nd4j-api"

This reverts commit 75f4e8da80d2551e44e1251dd6c5923289fff8e1.

# Conflicts:
#	nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java

* Resolve conflict

* Compilation fixed.

* nd4j-context and nd4j-buffer moved to nd4j-api

* Fixed TF mapping for mmul

* Fix for dl4j-cuda tests

Signed-off-by: Alex Black <blacka101@gmail.com>

* Move last few tests from deeplearning4j-nn to -core

Signed-off-by: Alex Black <blacka101@gmail.com>

* Remove incorrect TF import mapping for TensorMmul op

Signed-off-by: Alex Black <blacka101@gmail.com>

* Cleaned TF mapping

* Fix path for test results on windows

* Remove old dependency

Signed-off-by: Alex Black <blacka101@gmail.com>

* One more attempt to fix path for test results on windows

* fixup! One more attempt to fix path for test results on windows

* fixup! One more attempt to fix path for test results on windows

Co-authored-by: Alex Black <blacka101@gmail.com>
Co-authored-by: Serhii Shepel <9946053+sshepel@users.noreply.github.com>
Co-authored-by: raver119 <raver119@gmail.com>
master
Alexander Stoyakin 2020-01-24 21:35:00 +02:00 committed by raver119
parent 99a54829c2
commit 4db28a9300
156 changed files with 92 additions and 5299 deletions

View File

@ -1,116 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
<parent>
<artifactId>datavec-parent</artifactId>
<groupId>org.datavec</groupId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>datavec-camel</artifactId>
<name>DataVec Camel Component</name>
<url>http://deeplearning4j.org</url>
<dependencies>
<dependency>
<groupId>org.apache.camel</groupId>
<artifactId>camel-csv</artifactId>
<version>${camel.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.camel</groupId>
<artifactId>camel-core</artifactId>
<version>${camel.version}</version>
</dependency>
<!-- support camel documentation -->
<!-- logging -->
<!-- testing -->
<dependency>
<groupId>org.apache.camel</groupId>
<artifactId>camel-test</artifactId>
<version>${camel.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<defaultGoal>install</defaultGoal>
<plugins>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>1.7</source>
<target>1.7</target>
</configuration>
</plugin>
<plugin>
<artifactId>maven-resources-plugin</artifactId>
<version>3.0.1</version>
<configuration>
<encoding>UTF-8</encoding>
</configuration>
</plugin>
<!-- generate components meta-data and validate component includes documentation etc -->
<plugin>
<groupId>org.apache.camel</groupId>
<artifactId>camel-package-maven-plugin</artifactId>
<version>${camel.version}</version>
<executions>
<execution>
<id>prepare</id>
<goals>
<goal>prepare-components</goal>
</goals>
<phase>generate-resources</phase>
</execution>
<!-- <execution>
<id>validate</id>
<goals>
<goal>validate-components</goal>
</goals>
<phase>prepare-package</phase>
</execution>-->
</executions>
</plugin>
</plugins>
</build>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.2</id>
</profile>
</profiles>
</project>

View File

@ -1,45 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.camel.component;
import org.apache.camel.CamelContext;
import org.apache.camel.Endpoint;
import org.apache.camel.impl.UriEndpointComponent;
import java.util.Map;
/**
* Represents the component that manages {@link DataVecEndpoint}.
*/
public class DataVecComponent extends UriEndpointComponent {
public DataVecComponent() {
super(DataVecEndpoint.class);
}
public DataVecComponent(CamelContext context) {
super(context, DataVecEndpoint.class);
}
@Override
protected Endpoint createEndpoint(String uri, String remaining, Map<String, Object> parameters) throws Exception {
DataVecEndpoint endpoint = new DataVecEndpoint(uri, this);
setProperties(endpoint, parameters);
endpoint.setInputFormat(remaining);
return endpoint;
}
}

View File

@ -1,93 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.camel.component;
import org.apache.camel.Exchange;
import org.apache.camel.Processor;
import org.apache.camel.impl.ScheduledPollConsumer;
import org.datavec.api.conf.Configuration;
import org.datavec.api.formats.input.InputFormat;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.InputSplit;
/**
* The DataVec consumer.
* @author Adam Gibson
*/
public class DataVecConsumer extends ScheduledPollConsumer {
private final DataVecEndpoint endpoint;
private Class<? extends InputFormat> inputFormatClazz;
private Class<? extends DataVecMarshaller> marshallerClazz;
private InputFormat inputFormat;
private Configuration configuration;
private DataVecMarshaller marshaller;
public DataVecConsumer(DataVecEndpoint endpoint, Processor processor) {
super(endpoint, processor);
this.endpoint = endpoint;
try {
inputFormatClazz = (Class<? extends InputFormat>) Class.forName(endpoint.getInputFormat());
inputFormat = inputFormatClazz.newInstance();
marshallerClazz = (Class<? extends DataVecMarshaller>) Class.forName(endpoint.getInputMarshaller());
marshaller = marshallerClazz.newInstance();
configuration = new Configuration();
for (String prop : endpoint.getConsumerProperties().keySet())
configuration.set(prop, endpoint.getConsumerProperties().get(prop).toString());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
//stub, still need to fill out more of the end point yet..endpoint will likely be initialized with a split
protected InputSplit inputFromExchange(Exchange exchange) {
return marshaller.getSplit(exchange);
}
@Override
protected int poll() throws Exception {
Exchange exchange = endpoint.createExchange();
InputSplit split = inputFromExchange(exchange);
RecordReader reader = inputFormat.createReader(split, configuration);
int numMessagesPolled = 0;
while (reader.hasNext()) {
// create a message body
while (reader.hasNext()) {
exchange.getIn().setBody(reader.next());
try {
// send message to next processor in the route
getProcessor().process(exchange);
numMessagesPolled++; // number of messages polled
} finally {
// log exception if an exception occurred and was not handled
if (exchange.getException() != null) {
getExceptionHandler().handleException("Error processing exchange", exchange,
exchange.getException());
}
}
}
}
return numMessagesPolled;
}
}

View File

@ -1,68 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.camel.component;
import lombok.Data;
import org.apache.camel.Consumer;
import org.apache.camel.Processor;
import org.apache.camel.Producer;
import org.apache.camel.impl.DefaultEndpoint;
import org.apache.camel.spi.Metadata;
import org.apache.camel.spi.UriEndpoint;
import org.apache.camel.spi.UriParam;
import org.apache.camel.spi.UriPath;
/**
* Represents a DataVec endpoint.
* @author Adam Gibson
*/
@UriEndpoint(scheme = "datavec", title = "datavec", syntax = "datavec:inputFormat/?outputFormat=?&inputMarshaller=?",
consumerClass = DataVecConsumer.class, label = "datavec")
@Data
public class DataVecEndpoint extends DefaultEndpoint {
@UriPath
@Metadata(required = "true")
private String inputFormat;
@UriParam(defaultValue = "")
private String outputFormat;
@UriParam
@Metadata(required = "true")
private String inputMarshaller;
@UriParam(defaultValue = "org.datavec.api.io.converters.SelfWritableConverter")
private String writableConverter;
public DataVecEndpoint(String uri, DataVecComponent component) {
super(uri, component);
}
public DataVecEndpoint(String endpointUri) {
super(endpointUri);
}
public Producer createProducer() throws Exception {
return new DataVecProducer(this);
}
public Consumer createConsumer(Processor processor) throws Exception {
return new DataVecConsumer(this, processor);
}
public boolean isSingleton() {
return true;
}
}

View File

@ -1,36 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.camel.component;
import org.apache.camel.Exchange;
import org.datavec.api.split.InputSplit;
/**
* Marshals na exchange in to an input split
* @author Adam Gibson
*/
public interface DataVecMarshaller {
/**
*
* @param exchange
* @return
*/
InputSplit getSplit(Exchange exchange);
}

View File

@ -1,109 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.camel.component;
import org.apache.camel.Exchange;
import org.apache.camel.impl.DefaultProducer;
import org.datavec.api.conf.Configuration;
import org.datavec.api.formats.input.InputFormat;
import org.datavec.api.io.WritableConverter;
import org.datavec.api.io.converters.SelfWritableConverter;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.InputSplit;
import org.datavec.api.writable.Writable;
import java.util.ArrayList;
import java.util.Collection;
/**
* The DataVec producer.
* Converts input records in to their final form
* based on the input split generated from
* the given exchange.
*
* @author Adam Gibson
*/
public class DataVecProducer extends DefaultProducer {
private Class<? extends InputFormat> inputFormatClazz;
private Class<? extends DataVecMarshaller> marshallerClazz;
private InputFormat inputFormat;
private Configuration configuration;
private WritableConverter writableConverter;
private DataVecMarshaller marshaller;
public DataVecProducer(DataVecEndpoint endpoint) {
super(endpoint);
if (endpoint.getInputFormat() != null) {
try {
inputFormatClazz = (Class<? extends InputFormat>) Class.forName(endpoint.getInputFormat());
inputFormat = inputFormatClazz.newInstance();
marshallerClazz = (Class<? extends DataVecMarshaller>) Class.forName(endpoint.getInputMarshaller());
Class<? extends WritableConverter> converterClazz =
(Class<? extends WritableConverter>) Class.forName(endpoint.getWritableConverter());
writableConverter = converterClazz.newInstance();
marshaller = marshallerClazz.newInstance();
configuration = new Configuration();
for (String prop : endpoint.getConsumerProperties().keySet())
configuration.set(prop, endpoint.getConsumerProperties().get(prop).toString());
} catch (Exception e) {
throw new RuntimeException(e);
}
}
}
//stub, still need to fill out more of the end point yet..endpoint will likely be initialized with a split
protected InputSplit inputFromExchange(Exchange exchange) {
return marshaller.getSplit(exchange);
}
@Override
public void process(Exchange exchange) throws Exception {
InputSplit split = inputFromExchange(exchange);
RecordReader reader = inputFormat.createReader(split, configuration);
Collection<Collection<Writable>> newRecord = new ArrayList<>();
if (!(writableConverter instanceof SelfWritableConverter)) {
newRecord = new ArrayList<>();
while (reader.hasNext()) {
Collection<Writable> newRecordAdd = new ArrayList<>();
// create a message body
Collection<Writable> next = reader.next();
for (Writable writable : next) {
newRecordAdd.add(writableConverter.convert(writable));
}
newRecord.add(newRecordAdd);
}
} else {
while (reader.hasNext()) {
// create a message body
Collection<Writable> next = reader.next();
newRecord.add(next);
}
}
exchange.getIn().setBody(newRecord);
exchange.getOut().setBody(newRecord);
}
}

View File

@ -1,42 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.camel.component.csv.marshaller;
import org.apache.camel.Exchange;
import org.datavec.api.split.InputSplit;
import org.datavec.api.split.ListStringSplit;
import org.datavec.camel.component.DataVecMarshaller;
import java.util.List;
/**
* Marshals List<List<String>>
*
* @author Adam Gibson
*/
public class ListStringInputMarshaller implements DataVecMarshaller {
/**
* @param exchange
* @return
*/
@Override
public InputSplit getSplit(Exchange exchange) {
List<List<String>> data = (List<List<String>>) exchange.getIn().getBody();
InputSplit listSplit = new ListStringSplit(data);
return listSplit;
}
}

View File

@ -1 +0,0 @@
class=org.datavec.camel.component.DataVecComponent

View File

@ -1,82 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.camel.component;
import org.apache.camel.builder.RouteBuilder;
import org.apache.camel.component.mock.MockEndpoint;
import org.apache.camel.test.junit4.CamelTestSupport;
import org.apache.commons.io.FileUtils;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.writable.Writable;
import org.junit.BeforeClass;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.nd4j.linalg.io.ClassPathResource;
import java.io.File;
import java.util.ArrayList;
import java.util.Collection;
public class DataVecComponentTest extends CamelTestSupport {
@ClassRule
public static TemporaryFolder testDir = new TemporaryFolder();
private static File dir;
private static File irisFile;
@BeforeClass
public static void before() throws Exception {
dir = testDir.newFolder();
File iris = new ClassPathResource("iris.dat").getFile();
irisFile = new File(dir, "iris.dat");
FileUtils.copyFile(iris, irisFile );
}
@Test
public void testDataVec() throws Exception {
MockEndpoint mock = getMockEndpoint("mock:result");
//1
mock.expectedMessageCount(1);
RecordReader reader = new CSVRecordReader();
reader.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));
Collection<Collection<Writable>> recordAssertion = new ArrayList<>();
while (reader.hasNext())
recordAssertion.add(reader.next());
mock.expectedBodiesReceived(recordAssertion);
assertMockEndpointsSatisfied();
}
@Override
protected RouteBuilder createRouteBuilder() throws Exception {
return new RouteBuilder() {
public void configure() {
from("file:" + dir.getAbsolutePath() + "?fileName=iris.dat&noop=true").unmarshal().csv()
.to("datavec://org.datavec.api.formats.input.impl.ListStringInputFormat?inputMarshaller=org.datavec.camel.component.ListStringInputMarshaller&writableConverter=org.datavec.api.io.converters.SelfWritableConverter")
.to("mock:result");
}
};
}
}

View File

@ -1,41 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.camel.component;
import org.apache.camel.Exchange;
import org.datavec.api.split.InputSplit;
import org.datavec.api.split.ListStringSplit;
import java.util.List;
/**
* Marshals List<List<String>>
*
* @author Adam Gibson
*/
public class ListStringInputMarshaller implements DataVecMarshaller {
/**
* @param exchange
* @return
*/
@Override
public InputSplit getSplit(Exchange exchange) {
List<List<String>> data = (List<List<String>>) exchange.getIn().getBody();
InputSplit listSplit = new ListStringSplit(data);
return listSplit;
}
}

View File

@ -1,30 +0,0 @@
################################################################################
# Copyright (c) 2015-2018 Skymind, Inc.
#
# This program and the accompanying materials are made available under the
# terms of the Apache License, Version 2.0 which is available at
# https://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# SPDX-License-Identifier: Apache-2.0
################################################################################
#
# The logging properties used
#
log4j.rootLogger=INFO, out
# uncomment the following line to turn on Camel debugging
#log4j.logger.org.apache.camel=DEBUG
# CONSOLE appender not used by default
log4j.appender.out=org.apache.log4j.ConsoleAppender
log4j.appender.out.layout=org.apache.log4j.PatternLayout
log4j.appender.out.layout.ConversionPattern=[%30.30t] %-30.30c{1} %-5p %m%n
#log4j.appender.out.layout.ConversionPattern=%d [%-15.15t] %-5p %-30.30c{1} - %m%n

View File

@ -37,11 +37,6 @@
<version>${logback.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-buffer</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>com.github.jai-imageio</groupId>
<artifactId>jai-imageio-core</artifactId>

View File

@ -1,65 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>datavec-parent</artifactId>
<groupId>org.datavec</groupId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>datavec-perf</artifactId>
<name>datavec-perf</name>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.7</maven.compiler.source>
<maven.compiler.target>1.7</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
<version>${slf4j.version}</version>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-data-image</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.datavec</groupId>
<artifactId>datavec-api</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.2</id>
</profile>
</profiles>
</project>

View File

@ -1,112 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.perf.timing;
import lombok.val;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.split.InputStreamInputSplit;
import org.datavec.api.writable.Writable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
import org.nd4j.linalg.memory.MemcpyDirection;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.List;
import java.util.Map;
/**
* Timing components of a data vec pipeline
* consisting of:
* {@link RecordReader}, {@link InputStreamInputSplit}
* (note that this uses input stream input split,
* the record reader must support {@link InputStreamInputSplit} for this to work)
*
* @author Adam Gibson
*/
public class IOTiming {
/**
* Returns statistics for components of a datavec pipeline
* averaged over the specified number of times
* @param nTimes the number of times to run the pipeline for averaging
* @param recordReader the record reader
* @param file the file to read
* @param function the function
* @return the averaged {@link TimingStatistics} for input/output on a record
* reader and ndarray creation (based on the given function
* @throws Exception
*/
public static TimingStatistics averageFileRead(long nTimes, RecordReader recordReader, File file, INDArrayCreationFunction function) throws Exception {
TimingStatistics timingStatistics = null;
for(int i = 0; i < nTimes; i++) {
TimingStatistics timingStatistics1 = timeNDArrayCreation(recordReader,new BufferedInputStream(new FileInputStream(file)),function);
if(timingStatistics == null)
timingStatistics = timingStatistics1;
else {
timingStatistics = timingStatistics.add(timingStatistics1);
}
}
return timingStatistics.average(nTimes);
}
/**
*
* @param reader
* @param inputStream
* @param function
* @return
* @throws Exception
*/
public static TimingStatistics timeNDArrayCreation(RecordReader reader,
InputStream inputStream,
INDArrayCreationFunction function) throws Exception {
reader.initialize(new InputStreamInputSplit(inputStream));
long longNanos = System.nanoTime();
List<Writable> next = reader.next();
long endNanos = System.nanoTime();
long etlDiff = endNanos - longNanos;
long startArrCreation = System.nanoTime();
INDArray arr = function.createFromRecord(next);
long endArrCreation = System.nanoTime();
long endCreationDiff = endArrCreation - startArrCreation;
Map<Integer, Map<MemcpyDirection, Long>> currentBandwidth = PerformanceTracker.getInstance().getCurrentBandwidth();
val bw = currentBandwidth.get(0).get(MemcpyDirection.HOST_TO_DEVICE);
val deviceToHost = currentBandwidth.get(0).get(MemcpyDirection.HOST_TO_DEVICE);
return TimingStatistics.builder()
.diskReadingTimeNanos(etlDiff)
.bandwidthNanosHostToDevice(bw)
.bandwidthDeviceToHost(deviceToHost)
.ndarrayCreationTimeNanos(endCreationDiff)
.build();
}
public interface INDArrayCreationFunction {
INDArray createFromRecord(List<Writable> record);
}
}

View File

@ -1,74 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.perf.timing;
import lombok.Builder;
import lombok.Data;
/**
* The timing statistics for a data pipeline including:
* ndarray creation time in nanoseconds
* disk reading time in nanoseconds
* bandwidth used in host to device in nano seconds
* bandwidth device to host in nanoseconds
*
* @author Adam Gibson
*/
@Builder
@Data
public class TimingStatistics {
private long ndarrayCreationTimeNanos;
private long diskReadingTimeNanos;
private long bandwidthNanosHostToDevice;
private long bandwidthDeviceToHost;
/**
* Accumulate the given statistics
* @param timingStatistics the statistics to add
* @return the added statistics
*/
public TimingStatistics add(TimingStatistics timingStatistics) {
return TimingStatistics.builder()
.ndarrayCreationTimeNanos(ndarrayCreationTimeNanos + timingStatistics.ndarrayCreationTimeNanos)
.bandwidthNanosHostToDevice(bandwidthNanosHostToDevice + timingStatistics.bandwidthNanosHostToDevice)
.diskReadingTimeNanos(diskReadingTimeNanos + timingStatistics.diskReadingTimeNanos)
.bandwidthDeviceToHost(bandwidthDeviceToHost + timingStatistics.bandwidthDeviceToHost)
.build();
}
/**
* Average the results relative to the number of n.
* This method is meant to be used alongside
* {@link #add(TimingStatistics)}
* accumulated a number of times
* @param n n the number of elements
* @return the averaged results
*/
public TimingStatistics average(long n) {
return TimingStatistics.builder()
.ndarrayCreationTimeNanos(ndarrayCreationTimeNanos / n)
.bandwidthDeviceToHost(bandwidthDeviceToHost / n)
.diskReadingTimeNanos(diskReadingTimeNanos / n)
.bandwidthNanosHostToDevice(bandwidthNanosHostToDevice / n)
.build();
}
}

View File

@ -1,60 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.datavec.datavec.timing;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.Writable;
import org.datavec.image.loader.NativeImageLoader;
import org.datavec.image.recordreader.ImageRecordReader;
import org.datavec.perf.timing.IOTiming;
import org.datavec.perf.timing.TimingStatistics;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.io.ClassPathResource;
import java.util.List;
public class IOTimingTest {
@Test
public void testTiming() throws Exception {
final RecordReader image = new ImageRecordReader(28,28);
final NativeImageLoader nativeImageLoader = new NativeImageLoader(28,28);
TimingStatistics timingStatistics = IOTiming.timeNDArrayCreation(image, new ClassPathResource("datavec-perf/largestblobtest.jpg").getInputStream(), new IOTiming.INDArrayCreationFunction() {
@Override
public INDArray createFromRecord(List<Writable> record) {
NDArrayWritable imageWritable = (NDArrayWritable) record.get(0);
return imageWritable.get();
}
});
System.out.println(timingStatistics);
TimingStatistics timingStatistics1 = IOTiming.averageFileRead(1000,image,new ClassPathResource("datavec-perf/largestblobtest.jpg").getFile(), new IOTiming.INDArrayCreationFunction() {
@Override
public INDArray createFromRecord(List<Writable> record) {
NDArrayWritable imageWritable = (NDArrayWritable) record.get(0);
return imageWritable.get();
}
});
System.out.println(timingStatistics1);
}
}

View File

@ -59,16 +59,12 @@
<modules>
<module>datavec-api</module>
<module>datavec-data</module>
<module>datavec-geo</module>
<module>datavec-hadoop</module>
<module>datavec-spark</module>
<module>datavec-camel</module>
<module>datavec-local</module>
<module>datavec-spark-inference-parent</module>
<module>datavec-jdbc</module>
<module>datavec-excel</module>
<module>datavec-arrow</module>
<module>datavec-perf</module>
<module>datavec-python</module>
</modules>

View File

@ -153,6 +153,17 @@
<version>${jaxb.version}</version>
<scope>provided</scope>
</dependency>
<!-- Dependencies for dl4j-perf subproject -->
<dependency>
<groupId>com.github.oshi</groupId>
<artifactId>oshi-json</artifactId>
<version>${oshi.version}</version>
</dependency>
<dependency>
<groupId>com.github.oshi</groupId>
<artifactId>oshi-core</artifactId>
<version>${oshi.version}</version>
</dependency>
</dependencies>
<profiles>

View File

@ -14,11 +14,13 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.optimize.solvers.accumulation;
package org.deeplearning4j.optimize.solver.accumulation;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.optimize.solvers.accumulation.EncodedGradientsAccumulator;
import org.deeplearning4j.optimize.solvers.accumulation.EncodingHandler;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.FixedThresholdAlgorithm;
import org.junit.Test;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -68,8 +70,8 @@ public class EncodedGradientsAccumulatorTest extends BaseDL4JTest {
accumulator.receiveUpdate(encoded);
// just purge updates, like they were consumed
for (int i = 0; i < accumulator.messages.size(); i++) {
accumulator.messages.get(i).clear();
for (int i = 0; i < accumulator.getMessages().size(); i++) {
accumulator.getMessages().get(i).clear();
}
}
}

View File

@ -14,12 +14,13 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.optimize.solvers.accumulation;
package org.deeplearning4j.optimize.solver.accumulation;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.RandomUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.optimize.solvers.accumulation.IndexedTail;
import org.junit.Ignore;
import org.junit.Test;
import org.nd4j.linalg.factory.Nd4j;
@ -97,9 +98,9 @@ public class IndexedTailTest extends BaseDL4JTest {
assertEquals(-1, tail.maxAppliedIndexEverywhere());
// 2 consumers consumed 2 elements, and 1 consumer consumed 3 elements
tail.positions.get(11L).set(2);
tail.positions.get(22L).set(2);
tail.positions.get(33L).set(3);
tail.getPositions().get(11L).set(2);
tail.getPositions().get(22L).set(2);
tail.getPositions().get(33L).set(3);
// all elements including this index are safe to remove, because they were consumed everywhere
assertEquals(2, tail.maxAppliedIndexEverywhere());
@ -197,10 +198,10 @@ public class IndexedTailTest extends BaseDL4JTest {
Nd4j.getExecutioner().commit();
}
assertTrue(tail.collapsedMode.get());
assertTrue(tail.getCollapsedMode().get());
assertEquals(1, tail.updatesSize());
val array = tail.updates.get(32L);
val array = tail.getUpdates().get(32L);
assertNotNull(array);
assertEquals(sum, (int) array.getDouble(0));
}

View File

@ -14,12 +14,13 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.optimize.solvers.accumulation;
package org.deeplearning4j.optimize.solver.accumulation;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.lang3.RandomUtils;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.optimize.solvers.accumulation.SmartFancyBlockingQueue;
import org.deeplearning4j.util.ThreadUtils;
import org.junit.Ignore;
import org.junit.Test;

View File

@ -14,7 +14,7 @@
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.optimize.solvers.accumulation;
package org.deeplearning4j.optimize.solver.accumulation;
import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm;

View File

@ -1,5 +1,6 @@
package org.deeplearning4j.optimize.listeners;
package org.deeplearning4j.optimizer.listener;
import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener;
import org.junit.Ignore;
import org.junit.Test;

View File

@ -305,7 +305,7 @@ public class CuDNNGradientChecks extends BaseDL4JTest {
//However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev"));
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, excludeParams);
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, false, -1, excludeParams, null);
assertTrue(gradOK);
}
@ -417,7 +417,7 @@ public class CuDNNGradientChecks extends BaseDL4JTest {
}
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 32);
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 32, null, null);
assertTrue(gradOK);
}
@ -655,9 +655,12 @@ public class CuDNNGradientChecks extends BaseDL4JTest {
};
log.info("*** Starting test: " + msg + " ***");
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, l, null, null,
false, -1, null, c);
boolean gradOK = GradientCheckUtil.checkGradients(
new GradientCheckUtil.MLNConfig().net(mln).epsilon(DEFAULT_EPS)
.maxRelError(DEFAULT_MAX_REL_ERROR).minAbsoluteError(DEFAULT_MIN_ABS_ERROR)
.print(PRINT_RESULTS ? GradientCheckUtil.PrintMode.ZEROS : GradientCheckUtil.PrintMode.FAILURES_ONLY)
.exitOnFirstError(RETURN_ON_FIRST_FAILURE)
.input(f).labels(l).callEachIter(c));
assertTrue(msg, gradOK);
TestUtils.testModelSerialization(mln);
@ -691,7 +694,7 @@ public class CuDNNGradientChecks extends BaseDL4JTest {
//However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev"));
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, excludeParams);
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, null, null, false, -1, excludeParams, null);
assertTrue(gradOK);

View File

@ -63,6 +63,7 @@ public class EncodedGradientsAccumulator implements GradientsAccumulator, Regist
protected int parties;
@Getter
protected MessageHandler handler;
@Getter
protected List<BlockingQueue<INDArray>> messages = new ArrayList<>();
protected List<MemoryWorkspace> workspaces = new ArrayList<>();
protected List<ReentrantLock> locks = new ArrayList<>();
@ -106,7 +107,7 @@ public class EncodedGradientsAccumulator implements GradientsAccumulator, Regist
this(parties, new EncodingHandler(thresholdAlgorithm, residualPostProcessor, 1.0, encodingDebugMode), DEFAULT_INITIAL_MEMORY, 10, 1.0, encodingDebugMode);
}
protected EncodedGradientsAccumulator(int parties, @NonNull MessageHandler handler, long initialMemory,
public EncodedGradientsAccumulator(int parties, @NonNull MessageHandler handler, long initialMemory,
int queueSize, Double boundary, boolean encodingDebugMode) {
this.parties = parties;
this.handler = handler;

View File

@ -16,6 +16,7 @@
package org.deeplearning4j.optimize.solvers.accumulation;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
@ -44,9 +45,11 @@ import java.util.concurrent.locks.ReentrantReadWriteLock;
@Slf4j
public class IndexedTail {
// here we store positions of individual consumers
@Getter
protected ConcurrentHashMap<Long, AtomicLong> positions = new ConcurrentHashMap<>();
// here we store individual updates
@Getter
protected Map<Long, INDArray> updates = new ConcurrentHashMap<>();
// simple counter for new updates
@ -67,6 +70,7 @@ public class IndexedTail {
protected final boolean allowCollapse;
protected final long[] shape;
protected final int collapseThreshold = 32;
@Getter
protected AtomicBoolean collapsedMode = new AtomicBoolean(false);
protected AtomicLong collapsedIndex = new AtomicLong(-1);
@ -148,7 +152,7 @@ public class IndexedTail {
}
}
protected long firstNotAppliedIndexEverywhere() {
public long firstNotAppliedIndexEverywhere() {
long maxIdx = -1;
// if there's no updates posted yet - just return negative value
@ -163,7 +167,7 @@ public class IndexedTail {
return maxIdx + 1;
}
protected long maxAppliedIndexEverywhere() {
public long maxAppliedIndexEverywhere() {
long maxIdx = Long.MAX_VALUE;
for (val v:positions.values()) {
if (v.get() < maxIdx)
@ -212,7 +216,7 @@ public class IndexedTail {
return getDelta(Thread.currentThread().getId());
}
protected long getDelta(long threadId) {
public long getDelta(long threadId) {
return getGlobalPosition() - getLocalPosition(threadId);
}
@ -293,7 +297,7 @@ public class IndexedTail {
/**
* This method does maintenance of updates within
*/
protected synchronized void maintenance() {
public synchronized void maintenance() {
// first of all we're checking, if all consumers were already registered. if not - just no-op.
if (positions.size() < expectedConsumers) {
log.trace("Skipping maintanance due to not all expected consumers shown up: [{}] vs [{}]", positions.size(), expectedConsumers);
@ -326,7 +330,7 @@ public class IndexedTail {
* This method returns actual number of updates stored within tail
* @return
*/
protected int updatesSize() {
public int updatesSize() {
return updates.size();
}
@ -348,11 +352,11 @@ public class IndexedTail {
return result;
}
protected boolean isDead() {
public boolean isDead() {
return dead.get();
}
protected void notifyDead() {
public void notifyDead() {
dead.set(true);
}

View File

@ -1,158 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j;
import org.apache.commons.compress.utils.IOUtils;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
import java.io.*;
import java.util.Random;
import static org.junit.Assert.assertEquals;
public class TestUtils {
public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){
MultiLayerNetwork restored;
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ModelSerializer.writeModel(net, baos, true);
byte[] bytes = baos.toByteArray();
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
assertEquals(net.params(), restored.params());
} catch (IOException e){
//Should never happen
throw new RuntimeException(e);
}
//Also check the MultiLayerConfiguration is serializable (required by Spark etc)
MultiLayerConfiguration conf = net.getLayerWiseConfigurations();
serializeDeserializeJava(conf);
return restored;
}
public static ComputationGraph testModelSerialization(ComputationGraph net){
ComputationGraph restored;
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
ModelSerializer.writeModel(net, baos, true);
byte[] bytes = baos.toByteArray();
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
restored = ModelSerializer.restoreComputationGraph(bais, true);
assertEquals(net.getConfiguration(), restored.getConfiguration());
assertEquals(net.params(), restored.params());
} catch (IOException e){
//Should never happen
throw new RuntimeException(e);
}
//Also check the ComputationGraphConfiguration is serializable (required by Spark etc)
ComputationGraphConfiguration conf = net.getConfiguration();
serializeDeserializeJava(conf);
return restored;
}
private static <T> T serializeDeserializeJava(T object){
byte[] bytes;
try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){
oos.writeObject(object);
oos.close();
bytes = baos.toByteArray();
} catch (IOException e){
//Should never happen
throw new RuntimeException(e);
}
T out;
try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))){
out = (T)ois.readObject();
} catch (IOException | ClassNotFoundException e){
throw new RuntimeException(e);
}
assertEquals(object, out);
return out;
}
public static INDArray randomOneHot(long examples, long nOut){
return randomOneHot(examples, nOut, new Random(12345));
}
public static INDArray randomOneHot(long examples, long nOut, long rngSeed){
return randomOneHot(examples, nOut, new Random(rngSeed));
}
public static INDArray randomOneHot(long examples, long nOut, Random rng){
INDArray arr = Nd4j.create(examples, nOut);
for( int i=0; i<examples; i++ ){
arr.putScalar(i, rng.nextInt((int) nOut), 1.0);
}
return arr;
}
public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength){
return randomOneHotTimeSeries(minibatch, outSize, tsLength, new Random());
}
public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, long rngSeed){
return randomOneHotTimeSeries(minibatch, outSize, tsLength, new Random(rngSeed));
}
public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng){
INDArray out = Nd4j.create(new int[]{minibatch, outSize, tsLength}, 'f');
for( int i=0; i<minibatch; i++ ){
for( int j=0; j<tsLength; j++ ){
out.putScalar(i, rng.nextInt(outSize), j, 1.0);
}
}
return out;
}
public static INDArray randomBernoulli(int... shape) {
return randomBernoulli(0.5, shape);
}
public static INDArray randomBernoulli(double p, int... shape){
INDArray ret = Nd4j.createUninitialized(shape);
Nd4j.getExecutioner().exec(new BernoulliDistribution(ret, p));
return ret;
}
public static void writeStreamToFile(File out, InputStream is) throws IOException {
byte[] b = IOUtils.toByteArray(is);
try (OutputStream os = new BufferedOutputStream(new FileOutputStream(out))) {
os.write(b);
}
}
}

View File

@ -1,99 +0,0 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-scaleout</artifactId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<artifactId>deeplearning4j-aws_2.11</artifactId>
<name>DeepLearning4j-AWS</name>
<version>1.0.0-SNAPSHOT</version>
<build>
<plugins>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<configuration>
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>
</plugins>
</build>
<properties>
<!-- Default scala versions, may be overwritten by build profiles -->
<scala.version>2.11.12</scala.version>
<scala.binary.version>2.11</scala.binary.version>
</properties>
<dependencies>
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-java-sdk</artifactId>
<version>1.11.24</version>
</dependency>
<dependency>
<groupId>args4j</groupId>
<artifactId>args4j</artifactId>
<version>2.32</version>
</dependency>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-api</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-util</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.jcraft</groupId>
<artifactId>jsch</artifactId>
<version>${jsch.version}</version>
</dependency>
<dependency>
<groupId>org.threadly</groupId>
<artifactId>threadly</artifactId>
<version>${threadly.version}</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-lang3</artifactId>
<version>${commons-lang3.version}</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.2</id>
</profile>
</profiles>
</project>

View File

@ -1,34 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.aws.dataset;
import org.deeplearning4j.aws.s3.reader.S3Downloader;
import java.io.InputStream;
public class DataSetLoader {
private String bucket;
public void onData(InputStream is) {
S3Downloader downloader = new S3Downloader();
}
}

View File

@ -1,222 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.aws.ec2;
import com.amazonaws.regions.Regions;
import com.amazonaws.services.ec2.AmazonEC2;
import com.amazonaws.services.ec2.model.*;
import org.deeplearning4j.aws.s3.BaseS3;
import org.deeplearning4j.util.ThreadUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
* Creates Ec2Boxes
* @author Adam Gibson
*
*/
public class Ec2BoxCreator extends BaseS3 {
private String amiId;
private int numBoxes;
private String size;
private List<String> boxesCreated;
private String securityGroupId;
private String keyPair;
private Regions regions = Regions.DEFAULT_REGION;
private static final Logger log = LoggerFactory.getLogger(Ec2BoxCreator.class);
//centos
public final static String DEFAULT_AMI = "ami-8997afe0";
/**
*
* @param numBoxes number of boxes
* @param size the size of the instances
*/
public Ec2BoxCreator(int numBoxes, String size, String securityGroupId, String keyPair) {
this(DEFAULT_AMI, numBoxes, size, securityGroupId, keyPair);
}
/**
*
* @param amiId amazon image id
* @param numBoxes number of boxes
* @param size the size of the instances
* @param securityGroupId
*/
public Ec2BoxCreator(String amiId, int numBoxes, String size, String securityGroupId, String keyPair) {
super();
this.amiId = amiId;
this.numBoxes = numBoxes;
this.size = size;
this.keyPair = keyPair;
this.securityGroupId = securityGroupId;
}
public void createSpot() {
// Initializes a Spot Instance Request
RequestSpotInstancesRequest requestRequest = new RequestSpotInstancesRequest();
// Request 1 x t1.micro instance with a bid price of $0.03.
requestRequest.setSpotPrice("0.03");
requestRequest.setInstanceCount(Integer.valueOf(1));
// Setup the specifications of the launch. This includes the
// instance type (e.g. t1.micro) and the latest Amazon Linux
// AMI id available. Note, you should always use the latest
// Amazon Linux AMI id or another of your choosing.
LaunchSpecification launchSpecification = new LaunchSpecification();
launchSpecification.setImageId("ami-8c1fece5");
launchSpecification.setInstanceType("t1.micro");
// Add the security group to the request.
List<String> securityGroups = new ArrayList<>();
securityGroups.add("GettingStartedGroup");
launchSpecification.setSecurityGroups(securityGroups);
// Add the launch specifications to the request.
requestRequest.setLaunchSpecification(launchSpecification);
// Call the RequestSpotInstance API.
RequestSpotInstancesResult requestResult = getEc2().requestSpotInstances(requestRequest);
List<SpotInstanceRequest> requestResponses = requestResult.getSpotInstanceRequests();
// Setup an arraylist to collect all of the request ids we want to
// watch hit the running state.
List<String> spotInstanceRequestIds = new ArrayList<>();
// Add all of the request ids to the hashset, so we can determine when they hit the
// active state.
for (SpotInstanceRequest requestResponse : requestResponses) {
System.out.println("Created Spot Request: " + requestResponse.getSpotInstanceRequestId());
spotInstanceRequestIds.add(requestResponse.getSpotInstanceRequestId());
}
}
public void setRegion(Regions regions) {
this.regions = regions;
}
/**
* Create the instances
*/
public void create() {
RunInstancesRequest runInstancesRequest =
new RunInstancesRequest().withImageId(amiId).withInstanceType(size).withKeyName(keyPair)
.withMinCount(1).withSecurityGroupIds(securityGroupId).withMaxCount(numBoxes);
AmazonEC2 ec2 = getEc2();
ec2.setRegion(com.amazonaws.regions.Region.getRegion(regions));
List<Instance> boxes = ec2.runInstances(runInstancesRequest).getReservation().getInstances();
if (boxesCreated == null) {
boxesCreated = new ArrayList<>();
for (Instance i : boxes)
boxesCreated.add(i.getInstanceId());
log.info("Boxes created " + boxesCreated);
} else {
blowupBoxes();
boxesCreated.clear();
for (Instance i : boxes)
boxesCreated.add(i.getInstanceId());
}
}
public List<InstanceStateChange> blowupBoxes() {
TerminateInstancesRequest request = new TerminateInstancesRequest().withInstanceIds(boxesCreated);
if (boxesCreated != null) {
TerminateInstancesResult result = getEc2().terminateInstances(request);
List<InstanceStateChange> change = result.getTerminatingInstances();
log.info("Boxes destroyed " + boxesCreated);
return change;
}
return Collections.emptyList();
}
public void blockTillAllRunning() {
while (!allRunning()) {
ThreadUtils.uncheckedSleep(1000);
log.info("Not all created...");
}
}
public boolean allRunning() {
if (boxesCreated == null)
return false;
else {
DescribeInstancesResult result = getEc2().describeInstances();
List<Reservation> reservations = result.getReservations();
for (Reservation r : reservations) {
List<Instance> instances = r.getInstances();
for (Instance instance : instances) {
InstanceState state = instance.getState();
if (state.getCode() == 48)
continue;
if (state.getCode() != 16)
return false;
}
}
return true;
}
}
public List<String> getHosts() {
DescribeInstancesResult result = getEc2().describeInstances();
List<String> hosts = new ArrayList<>();
List<Reservation> reservations = result.getReservations();
for (Reservation r : reservations) {
List<Instance> instances = r.getInstances();
for (Instance instance : instances) {
InstanceState state = instance.getState();
if (state.getCode() == 48)
continue;
hosts.add(instance.getPublicDnsName());
}
}
return hosts;
}
public List<String> getBoxesCreated() {
return boxesCreated;
}
}

View File

@ -1,122 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.aws.ec2.provision;
import com.amazonaws.regions.Regions;
import org.deeplearning4j.aws.ec2.Ec2BoxCreator;
import org.kohsuke.args4j.CmdLineException;
import org.kohsuke.args4j.CmdLineParser;
import org.kohsuke.args4j.Option;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.threadly.concurrent.PriorityScheduler;
import java.util.List;
/**
* Sets up a DL4J cluster
* @author Adam Gibson
*
*/
public class ClusterSetup {
@Option(name = "-w", usage = "Number of workers")
private int numWorkers = 1;
@Option(name = "-ami", usage = "Amazon machine image: default, amazon linux (only works with RHEL right now")
private String ami = "ami-fb8e9292";
@Option(name = "-s", usage = "size of instance: default m1.medium")
private String size = "m3.xlarge";
@Option(name = "-sg", usage = "security group, this needs to be applyTransformToDestination")
private String securityGroupName;
@Option(name = "-kp", usage = "key pair name, also needs to be applyTransformToDestination.")
private String keyPairName;
@Option(name = "-kpath",
usage = "path to private key - needs to be applyTransformToDestination, this is used to login to amazon.")
private String pathToPrivateKey;
@Option(name = "-wscript", usage = "path to worker script to run, this will allow customization of dependencies")
private String workerSetupScriptPath;
@Option(name = "-mscript", usage = "path to master script to run this will allow customization of the dependencies")
private String masterSetupScriptPath;
@Option(name = "-region", usage = "specify a region")
private String region = Regions.US_EAST_1.getName();
private PriorityScheduler as;
private static final Logger log = LoggerFactory.getLogger(ClusterSetup.class);
public ClusterSetup(String[] args) {
CmdLineParser parser = new CmdLineParser(this);
try {
parser.parseArgument(args);
} catch (CmdLineException e) {
parser.printUsage(System.err);
log.error("Unable to parse args", e);
}
}
public void exec() {
//master + workers
Ec2BoxCreator boxCreator = new Ec2BoxCreator(ami, numWorkers, size, securityGroupName, keyPairName);
boxCreator.setRegion(Regions.fromName(region));
boxCreator.create();
boxCreator.blockTillAllRunning();
List<String> hosts = boxCreator.getHosts();
//provisionMaster(hosts.get(0));
provisionWorkers(hosts);
}
private void provisionWorkers(List<String> workers) {
as = new PriorityScheduler(Runtime.getRuntime().availableProcessors());
for (final String workerHost : workers) {
try {
as.execute(new Runnable() {
@Override
public void run() {
HostProvisioner uploader = new HostProvisioner(workerHost, "ec2-user");
try {
uploader.addKeyFile(pathToPrivateKey);
//uploader.runRemoteCommand("sudo hostname " + workerHost);
uploader.uploadAndRun(workerSetupScriptPath, "");
} catch (Exception e) {
e.printStackTrace();
}
}
});
} catch (Exception e) {
log.error("Error ", e);
}
}
}
/**
* @param args
*/
public static void main(String[] args) {
new ClusterSetup(args).exec();
}
}

View File

@ -1,31 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.aws.ec2.provision;
public class DistributedDeepLearningTrainer {
private DistributedDeepLearningTrainer() {}
/**
* @param args
*/
public static void main(String[] args) {
ClusterSetup clusterSet = new ClusterSetup(args);
}
}

View File

@ -1,311 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.aws.ec2.provision;
import com.jcraft.jsch.*;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.InputStream;
import java.util.Arrays;
import java.util.Collection;
/**
* Meant for uploading files to remote servers
* @author Adam Gibson
*
*/
public class HostProvisioner implements UserInfo {
private String host;
private JSch jsch;
private String user;
private int port = 22;
private String password;
private static final Logger log = LoggerFactory.getLogger(HostProvisioner.class);
/**
*
* @param host host to connect to (public facing dns)
* @param user the user to connect with (default root otherwise)
* @param password the password to use if any
* @param port the port to connect to(default 22)
*/
public HostProvisioner(String host, String user, String password, int port) {
super();
this.host = host;
this.user = user;
this.port = port;
this.password = password;
jsch = new JSch();
}
/**
* Connects to port 22
* @param host host to connect to (public facing dns)
* @param user the user to connect with (default root otherwise)
* @param password the password to use if any
*/
public HostProvisioner(String host, String user, String password) {
this(host, user, password, 22);
}
/**
* Connects to port 22
* @param host host to connect to (public facing dns)
* @param user the user to connect with (default root otherwise)
*/
public HostProvisioner(String host, String user) {
this(host, user, "", 22);
}
/**
* Connects to port 22, user root, with no password
* @param host host to connect to (public facing dns)
*/
public HostProvisioner(String host) {
this(host, "root", "", 22);
}
public void uploadAndRun(String script, String rootDir) throws Exception {
String remoteName = rootDir.isEmpty() ? new File(script).getName() : rootDir + "/" + new File(script).getName();
upload(new File(script), remoteName);
String remoteCommand = remoteName.charAt(0) != '/' ? "./" + remoteName : remoteName;
remoteCommand = "chmod +x " + remoteCommand + " && " + remoteCommand;
runRemoteCommand(remoteCommand);
}
public void runRemoteCommand(String remoteCommand) throws Exception {
Session session = getSession();
session.connect();
ChannelExec channel = (ChannelExec) session.openChannel("exec");
channel.setCommand(remoteCommand);
channel.setErrStream(System.err);
channel.setPty(true);
channel.setOutputStream(System.out);
channel.connect();
channel.start();
InputStream input = channel.getInputStream();
//start reading the input from the executed commands on the shell
byte[] tmp = new byte[60000];
while (true) {
while (input.available() > 0) {
int i = input.read(tmp, 0, tmp.length);
if (i < 0)
break;
log.info(new String(tmp, 0, i));
}
if (channel.isClosed()) {
log.info("exit-status: " + channel.getExitStatus());
break;
}
}
channel.disconnect();
session.disconnect();
}
private Session getSession() throws Exception {
Session session = jsch.getSession(user, host, port);
session.setUserInfo(this);
return session;
}
/**
* Creates the directory for the file if necessary
* and uploads the file
* @param from the directory to upload from
* @param to the destination directory on the remote server
* @throws Exception
*/
public void uploadForDeployment(String from, String to) throws Exception {
File fromFile = new File(from);
if (!to.isEmpty() && fromFile.isDirectory())
mkDir(to);
else
upload(from, to);
}
public void addKeyFile(String keyFile) throws Exception {
jsch.addIdentity(keyFile);
}
//creates the directory to upload to
private void mkDir(String dir) throws Exception {
Session session = getSession();
session.connect();
Channel channel = session.openChannel("sftp");
channel.connect();
ChannelSftp c = (ChannelSftp) channel;
if (!fileExists(dir, c))
c.mkdir(dir);
c.exit();
session.disconnect();
}
private boolean fileExists(String dir, ChannelSftp channel) {
try {
channel.stat(dir);
return true;
} catch (Exception e) {
return false;
}
}
//uploads the file or listed files in a directory
private void upload(String fileOrDir, String uploadRootDir) throws Exception {
if (uploadRootDir.isEmpty())
uploadRootDir = ".";
File origin = new File(fileOrDir);
if (fileOrDir.endsWith(".tar") || fileOrDir.endsWith(".tar.gz")) {
upload(new File(fileOrDir), uploadRootDir);
untar(uploadRootDir);
} else if (origin.isFile()) {
upload(new File(fileOrDir), uploadRootDir);
} else {
File[] childFiles = origin.listFiles();
if (childFiles != null)
upload(Arrays.asList(childFiles), uploadRootDir);
}
}
private void untar(String targetRemoteFile) throws Exception {
this.runRemoteCommand("tar xvf " + targetRemoteFile);
}
private void upload(Collection<File> files, String rootDir) throws Exception {
Session session = getSession();
session.connect();
Channel channel = session.openChannel("sftp");
channel.connect();
ChannelSftp c = (ChannelSftp) channel;
for (File f : files) {
if (f.isDirectory()) {
log.warn("Skipping " + f.getName());
continue;
}
log.info("Uploading " + f.getName());
BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f));
c.put(bis, rootDir + "/" + f.getName());
bis.close();
}
channel.disconnect();
session.disconnect();
}
private void upload(File f, String remoteFile) throws Exception {
Session session = getSession();
int numRetries = 0;
while (numRetries < 3 && !session.isConnected()) {
try {
session.connect();
} catch (Exception e) {
numRetries++;
}
}
try {
Channel channel = session.openChannel("sftp");
channel.connect();
ChannelSftp c = (ChannelSftp) channel;
BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f));
if (this.fileExists(remoteFile, c))
if (f.isDirectory())
c.rmdir(remoteFile);
else
c.rm(remoteFile);
c.put(bis, remoteFile);
bis.close();
c.exit();
session.disconnect();
} catch (Exception e) {
log.info("Session was down...trying again", e);
upload(f, remoteFile);
}
}
@Override
public String getPassphrase() {
return this.password;
}
@Override
public String getPassword() {
return this.password;
}
@Override
public boolean promptPassphrase(String arg0) {
return true;
}
@Override
public boolean promptPassword(String arg0) {
return true;
}
@Override
public boolean promptYesNo(String arg0) {
return true;
}
@Override
public void showMessage(String arg0) {
log.info(arg0);
}
}

View File

@ -1,47 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.aws.emr;
import com.amazonaws.services.elasticmapreduce.model.Configuration;
import lombok.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
@Data
@AllArgsConstructor(access = AccessLevel.PRIVATE)
@NoArgsConstructor
@Builder
public class EmrConfig {
protected String classification;
protected Map<String, String> properties;
protected List<EmrConfig> configs;
Configuration toAwsConfig() {
Configuration config = new Configuration().withClassification(classification).withProperties(properties);
List<Configuration> subConfigs = new ArrayList<>();
for (EmrConfig conf : configs){
subConfigs.add(conf.toAwsConfig());
}
return config.withConfigurations(subConfigs);
}
}

View File

@ -1,531 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.aws.emr;
import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce;
import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduceClientBuilder;
import com.amazonaws.services.elasticmapreduce.model.*;
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
import com.amazonaws.services.s3.AmazonS3URI;
import com.amazonaws.services.s3.model.PutObjectRequest;
import lombok.AccessLevel;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NoArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.RandomStringUtils;
import org.nd4j.linalg.function.Function;
import java.io.File;
import java.util.*;
/**
* Configuration for a Spark EMR cluster
*/
@Data
@AllArgsConstructor(access = AccessLevel.PRIVATE)
@NoArgsConstructor
@Slf4j
public class SparkEMRClient {
protected String sparkClusterName = RandomStringUtils.randomAlphanumeric(12);
protected String sparkAwsRegion = "us-east-1";
protected String sparkEmrRelease = "emr-5.9.0";
protected String sparkEmrServiceRole = "EMR_DefaultRole";
protected List<EmrConfig> sparkEmrConfigs = Collections.emptyList();
protected String sparkSubnetId = null;
protected List<String> sparkSecurityGroupIds = Collections.emptyList();
protected int sparkInstanceCount = 1;
protected String sparkInstanceType = "m3.xlarge";
protected Optional<Float> sparkInstanceBidPrice = Optional.empty();
protected String sparkInstanceRole = "EMR_EC2_DefaultRole";
protected String sparkS3JarFolder = "changeme";
protected int sparkTimeoutDurationMinutes = 90;
//underlying configs
protected AmazonElasticMapReduceClientBuilder sparkEmrClientBuilder;
protected AmazonS3ClientBuilder sparkS3ClientBuilder;
protected JobFlowInstancesConfig sparkJobFlowInstancesConfig;
protected RunJobFlowRequest sparkRunJobFlowRequest;
protected Function<PutObjectRequest, PutObjectRequest> sparkS3PutObjectDecorator;
protected Map<String, String> sparkSubmitConfs;
private static ClusterState[] activeClusterStates = new ClusterState[]{
ClusterState.RUNNING,
ClusterState.STARTING,
ClusterState.WAITING,
ClusterState.BOOTSTRAPPING};
private Optional<ClusterSummary> findClusterWithName(AmazonElasticMapReduce emr, String name) {
List<ClusterSummary> csrl = emr.listClusters((new ListClustersRequest()).withClusterStates(activeClusterStates)).getClusters();
for (ClusterSummary csr : csrl) {
if (csr.getName().equals(name)) return Optional.of(csr);
}
return Optional.empty();
}
/**
* Creates the current cluster
*/
public void createCluster() {
AmazonElasticMapReduce emr = sparkEmrClientBuilder.build();
Optional<ClusterSummary> csr = findClusterWithName(emr, sparkClusterName);
if (csr.isPresent()) {
String msg = String.format("A cluster with name %s and id %s is already deployed", sparkClusterName, csr.get().getId());
log.error(msg);
throw new IllegalStateException(msg);
} else {
RunJobFlowResult res = emr.runJobFlow(sparkRunJobFlowRequest);
String msg = String.format("Your cluster is launched with name %s and id %s.", sparkClusterName, res.getJobFlowId());
log.info(msg);
}
}
private void logClusters(List<ClusterSummary> csrl) {
if (csrl.isEmpty()) log.info("No cluster found.");
else {
log.info(String.format("%d clusters found.", csrl.size()));
for (ClusterSummary csr : csrl) {
log.info(String.format("Name: %s | Id: %s", csr.getName(), csr.getId()));
}
}
}
/**
* Lists existing active clusters Names
*
* @return cluster names
*/
public List<String> listActiveClusterNames() {
AmazonElasticMapReduce emr = sparkEmrClientBuilder.build();
List<ClusterSummary> csrl =
emr.listClusters(new ListClustersRequest().withClusterStates(activeClusterStates)).getClusters();
logClusters(csrl);
List<String> res = new ArrayList<>(csrl.size());
for (ClusterSummary csr : csrl) res.add(csr.getName());
return res;
}
/**
* List existing active cluster IDs
*
* @return cluster IDs
*/
public List<String> listActiveClusterIds() {
AmazonElasticMapReduce emr = sparkEmrClientBuilder.build();
List<ClusterSummary> csrl =
emr.listClusters(new ListClustersRequest().withClusterStates(activeClusterStates)).getClusters();
logClusters(csrl);
List<String> res = new ArrayList<>(csrl.size());
for (ClusterSummary csr : csrl) res.add(csr.getId());
return res;
}
/**
* Terminates a cluster
*/
public void terminateCluster() {
AmazonElasticMapReduce emr = sparkEmrClientBuilder.build();
Optional<ClusterSummary> optClusterSum = findClusterWithName(emr, sparkClusterName);
if (!optClusterSum.isPresent()) {
log.error(String.format("The cluster with name %s , requested for deletion, does not exist.", sparkClusterName));
} else {
String id = optClusterSum.get().getId();
emr.terminateJobFlows((new TerminateJobFlowsRequest()).withJobFlowIds(id));
log.info(String.format("The cluster with id %s is terminating.", id));
}
}
// The actual job-sumission logic
private void submitJob(AmazonElasticMapReduce emr, String mainClass, List<String> args, Map<String, String> sparkConfs, File uberJar) throws Exception {
AmazonS3URI s3Jar = new AmazonS3URI(sparkS3JarFolder + "/" + uberJar.getName());
log.info(String.format("Placing uberJar %s to %s", uberJar.getPath(), s3Jar.toString()));
PutObjectRequest putRequest = sparkS3PutObjectDecorator.apply(
new PutObjectRequest(s3Jar.getBucket(), s3Jar.getKey(), uberJar)
);
sparkS3ClientBuilder.build().putObject(putRequest);
// The order of these matters
List<String> sparkSubmitArgs = Arrays.asList(
"spark-submit",
"--deploy-mode",
"cluster",
"--class",
mainClass
);
for (Map.Entry<String, String> e : sparkConfs.entrySet()) {
sparkSubmitArgs.add(String.format("--conf %s = %s ", e.getKey(), e.getValue()));
}
sparkSubmitArgs.add(s3Jar.toString());
sparkSubmitArgs.addAll(args);
StepConfig step = new StepConfig()
.withActionOnFailure(ActionOnFailure.CONTINUE)
.withName("Spark step")
.withHadoopJarStep(
new HadoopJarStepConfig()
.withJar("command-runner.jar")
.withArgs(sparkSubmitArgs)
);
Optional<ClusterSummary> optCsr = findClusterWithName(emr, sparkClusterName);
if (optCsr.isPresent()) {
ClusterSummary csr = optCsr.get();
emr.addJobFlowSteps(
new AddJobFlowStepsRequest()
.withJobFlowId(csr.getId())
.withSteps(step));
log.info(
String.format("Your job is added to the cluster with id %s.", csr.getId())
);
} else {
// If the cluster wasn't started, it's assumed ot be throwaway
List<StepConfig> steps = sparkRunJobFlowRequest.getSteps();
steps.add(step);
RunJobFlowRequest jobFlowRequest = sparkRunJobFlowRequest
.withSteps(steps)
.withInstances(sparkJobFlowInstancesConfig.withKeepJobFlowAliveWhenNoSteps(false));
RunJobFlowResult res = emr.runJobFlow(jobFlowRequest);
log.info("Your new cluster's id is %s.", res.getJobFlowId());
}
}
/**
* Submit a Spark Job with a specified main class
*/
public void sparkSubmitJobWithMain(String[] args, String mainClass, File uberJar) throws Exception {
AmazonElasticMapReduce emr = sparkEmrClientBuilder.build();
submitJob(emr, mainClass, Arrays.asList(args), sparkSubmitConfs, uberJar);
}
private void checkStatus(AmazonElasticMapReduce emr, String clusterId) throws InterruptedException {
log.info(".");
com.amazonaws.services.elasticmapreduce.model.Cluster dcr =
emr.describeCluster((new DescribeClusterRequest()).withClusterId(clusterId)).getCluster();
String state = dcr.getStatus().getState();
long timeOutTime = System.currentTimeMillis() + ((long) sparkTimeoutDurationMinutes * 60 * 1000);
Boolean activated = Arrays.asList(activeClusterStates).contains(ClusterState.fromValue(state));
Boolean timedOut = System.currentTimeMillis() > timeOutTime;
if (activated && timedOut) {
emr.terminateJobFlows(
new TerminateJobFlowsRequest().withJobFlowIds(clusterId)
);
log.error("Timeout. Cluster terminated.");
} else if (!activated) {
Boolean hasAbnormalStep = false;
StepSummary stepS = null;
List<StepSummary> steps = emr.listSteps(new ListStepsRequest().withClusterId(clusterId)).getSteps();
for (StepSummary step : steps) {
if (step.getStatus().getState() != StepState.COMPLETED.toString()) {
hasAbnormalStep = true;
stepS = step;
}
}
if (hasAbnormalStep && stepS != null)
log.error(String.format("Cluster %s terminated with an abnormal step, name %s, id %s", clusterId, stepS.getName(), stepS.getId()));
else
log.info("Cluster %s terminated without error.", clusterId);
} else {
Thread.sleep(5000);
checkStatus(emr, clusterId);
}
}
/**
* Monitor the cluster and terminates when it times out
*/
public void sparkMonitor() throws InterruptedException {
AmazonElasticMapReduce emr = sparkEmrClientBuilder.build();
Optional<ClusterSummary> optCsr = findClusterWithName(emr, sparkClusterName);
if (!optCsr.isPresent()) {
log.error(String.format("The cluster with name %s does not exist.", sparkClusterName));
} else {
ClusterSummary csr = optCsr.get();
log.info(String.format("found cluster with id %s, starting monitoring", csr.getId()));
checkStatus(emr, csr.getId());
}
}
@Data
public static class Builder {
protected String sparkClusterName = RandomStringUtils.randomAlphanumeric(12);
protected String sparkAwsRegion = "us-east-1";
protected String sparkEmrRelease = "emr-5.9.0";
protected String sparkEmrServiceRole = "EMR_DefaultRole";
protected List<EmrConfig> sparkEmrConfigs = Collections.emptyList();
protected String sparkSubNetid = null;
protected List<String> sparkSecurityGroupIds = Collections.emptyList();
protected int sparkInstanceCount = 1;
protected String sparkInstanceType = "m3.xlarge";
protected Optional<Float> sparkInstanceBidPrice = Optional.empty();
protected String sparkInstanceRole = "EMR_EC2_DefaultRole";
protected String sparkS3JarFolder = "changeme";
protected int sparkTimeoutDurationMinutes = 90;
protected AmazonElasticMapReduceClientBuilder sparkEmrClientBuilder;
protected AmazonS3ClientBuilder sparkS3ClientBuilder;
protected JobFlowInstancesConfig sparkJobFlowInstancesConfig;
protected RunJobFlowRequest sparkRunJobFlowRequest;
// This should allow the user to decorate the put call to add metadata to the jar put command, such as security groups,
protected Function<PutObjectRequest, PutObjectRequest> sparkS3PutObjectDecorator = new Function<PutObjectRequest, PutObjectRequest>() {
@Override
public PutObjectRequest apply(PutObjectRequest putObjectRequest) {
return putObjectRequest;
}
};
protected Map<String, String> sparkSubmitConfs;
/**
* Defines the EMR cluster's name
*
* @param clusterName the EMR cluster's name
* @return an EMR cluster builder
*/
public Builder clusterName(String clusterName) {
this.sparkClusterName = clusterName;
return this;
}
/**
* Defines the EMR cluster's region
* See https://docs.aws.amazon.com/general/latest/gr/rande.html
*
* @param region the EMR cluster's region
* @return an EMR cluster builder
*/
public Builder awsRegion(String region) {
this.sparkAwsRegion = region;
return this;
}
/**
* Defines the EMR release version to be used in this cluster
* uses a release label
* See https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-4.2.0/emr-release-differences.html#emr-release-label
*
* @param releaseLabel the EMR release label
* @return an EM cluster Builder
*/
public Builder emrRelease(String releaseLabel) {
this.sparkEmrRelease = releaseLabel;
return this;
}
/**
* Defines the IAM role to be assumed by the EMR service
* <p>
* https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_create_for-service.html
*
* @param serviceRole the service role
* @return an EM cluster Builder
*/
public Builder emrServiceRole(String serviceRole) {
this.sparkEmrServiceRole = serviceRole;
return this;
}
/**
* A list of configuration parameters to apply to EMR instances.
*
* @param configs the EMR configurations to apply to this cluster
* @return an EMR cluster builder
*/
public Builder emrConfigs(List<EmrConfig> configs) {
this.sparkEmrConfigs = configs;
return this;
}
/**
* The id of the EC2 subnet to be used for this Spark EMR service
* see https://docs.aws.amazon.com/AmazonVPC/latest/UserGuide/VPC_Subnets.html
*
* @param id the subnet ID
* @return an EMR cluster builder
*/
public Builder subnetId(String id) {
this.sparkSubNetid = id;
return this;
}
/**
* The id of additional security groups this deployment should adopt for both master and slaves
*
* @param securityGroups
* @return an EMR cluster builder
*/
public Builder securityGroupIDs(List<String> securityGroups) {
this.sparkSecurityGroupIds = securityGroups;
return this;
}
/**
* The number of instances this deployment should comprise of
*
* @param count the number of instances for this cluster
* @rturn an EMR cluster buidler
*/
public Builder instanceCount(int count) {
this.sparkInstanceCount = count;
return this;
}
/**
* The type of instance this cluster should comprise of
* See https://aws.amazon.com/ec2/instance-types/
*
* @param instanceType the type of instance for this cluster
* @return an EMR cluster builder
*/
public Builder instanceType(String instanceType) {
this.sparkInstanceType = instanceType;
return this;
}
/**
* The optional bid value for this cluster's spot instances
* see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/how-spot-instances-work.html
* Uses the on-demand market if empty.
*
* @param optBid the Optional bid price for this cluster's instnces
* @return an EMR cluster Builder
*/
public Builder instanceBidPrice(Optional<Float> optBid) {
this.sparkInstanceBidPrice = optBid;
return this;
}
/**
* The EC2 instance role that this cluster's instances should assume
* see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/iam-roles-for-amazon-ec2.html
*
* @param role the intended instance role
* @return an EMR cluster builder
*/
public Builder instanceRole(String role) {
this.sparkInstanceRole = role;
return this;
}
/**
* the S3 folder in which to find the application jar
*
* @param jarfolder the S3 folder in which to find a jar
* @return an EMR cluster builder
*/
public Builder s3JarFolder(String jarfolder) {
this.sparkS3JarFolder = jarfolder;
return this;
}
/**
* The timeout duration for this Spark EMR cluster, in minutes
*
* @param timeoutMinutes
* @return an EMR cluster builder
*/
public Builder sparkTimeOutDurationMinutes(int timeoutMinutes) {
this.sparkTimeoutDurationMinutes = timeoutMinutes;
return this;
}
/**
* Creates an EMR Spark cluster deployment
*
* @return a SparkEMRClient
*/
public SparkEMRClient build() {
this.sparkEmrClientBuilder = AmazonElasticMapReduceClientBuilder.standard().withRegion(sparkAwsRegion);
this.sparkS3ClientBuilder = AmazonS3ClientBuilder.standard().withRegion(sparkAwsRegion);
// note this will be kept alive without steps, an arbitrary choice to avoid rapid test-teardown-restart cycles
this.sparkJobFlowInstancesConfig = (new JobFlowInstancesConfig()).withKeepJobFlowAliveWhenNoSteps(true);
if (this.sparkSubNetid != null)
this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withEc2SubnetId(this.sparkSubNetid);
if (!this.sparkSecurityGroupIds.isEmpty()) {
this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withAdditionalMasterSecurityGroups(this.sparkSecurityGroupIds);
this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withAdditionalSlaveSecurityGroups(this.sparkSecurityGroupIds);
}
InstanceGroupConfig masterConfig =
(new InstanceGroupConfig()).withInstanceCount(1).withInstanceRole(InstanceRoleType.MASTER).withInstanceType(sparkInstanceType);
if (sparkInstanceBidPrice.isPresent()) {
masterConfig = masterConfig.withMarket(MarketType.SPOT).withBidPrice(sparkInstanceBidPrice.get().toString());
} else {
masterConfig = masterConfig.withMarket(MarketType.ON_DEMAND);
}
int slaveCount = sparkInstanceCount - 1;
InstanceGroupConfig slaveConfig =
(new InstanceGroupConfig()).withInstanceCount(slaveCount).withInstanceRole(InstanceRoleType.CORE).withInstanceRole(sparkInstanceType);
if (sparkInstanceBidPrice.isPresent()) {
slaveConfig = slaveConfig.withMarket(MarketType.SPOT).withBidPrice(sparkInstanceBidPrice.get().toString());
} else {
slaveConfig = slaveConfig.withMarket(MarketType.ON_DEMAND);
}
if (slaveCount > 0) {
this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withInstanceGroups(Arrays.asList(masterConfig, slaveConfig));
} else {
this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withInstanceGroups(slaveConfig);
}
this.sparkRunJobFlowRequest = new RunJobFlowRequest();
if (!sparkEmrConfigs.isEmpty()) {
List<Configuration> emrConfigs = new ArrayList<>();
for (EmrConfig config : sparkEmrConfigs) {
emrConfigs.add(config.toAwsConfig());
}
this.sparkRunJobFlowRequest = this.sparkRunJobFlowRequest.withConfigurations(emrConfigs);
}
this.sparkRunJobFlowRequest =
this.sparkRunJobFlowRequest.withName(sparkClusterName).withApplications((new Application()).withName("Spark"))
.withReleaseLabel(sparkEmrRelease)
.withServiceRole(sparkEmrServiceRole)
.withJobFlowRole(sparkInstanceRole)
.withInstances(this.sparkJobFlowInstancesConfig);
return new SparkEMRClient(
this.sparkClusterName,
this.sparkAwsRegion,
this.sparkEmrRelease,
this.sparkEmrServiceRole,
this.sparkEmrConfigs,
this.sparkSubNetid,
this.sparkSecurityGroupIds,
this.sparkInstanceCount,
this.sparkInstanceType,
this.sparkInstanceBidPrice,
this.sparkInstanceRole,
this.sparkS3JarFolder,
this.sparkTimeoutDurationMinutes,
this.sparkEmrClientBuilder,
this.sparkS3ClientBuilder,
this.sparkJobFlowInstancesConfig,
this.sparkRunJobFlowRequest,
this.sparkS3PutObjectDecorator,
this.sparkSubmitConfs
);
}
}
}

View File

@ -1,117 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.aws.s3;
import com.amazonaws.auth.AWSCredentials;
import com.amazonaws.auth.BasicAWSCredentials;
import com.amazonaws.auth.PropertiesCredentials;
import com.amazonaws.services.ec2.AmazonEC2;
import com.amazonaws.services.ec2.AmazonEC2Client;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3Client;
import java.io.File;
import java.io.InputStream;
/**
* The S3 Credentials works via discovering the credentials
* from the system properties (passed in via -D or System wide)
* If you invoke the JVM with -Dorg.deeplearning4j.aws.accessKey=YOUR_ACCESS_KEY
* and -Dorg.deeplearning4j.aws.accessSecret=YOUR_SECRET_KEY
* this will pick up the credentials from there, otherwise it will also attempt to look in
* the system environment for the following variables:
*
*
* AWS_ACCESS_KEY_ID
* AWS_SECRET_ACCESS_KEY
* @author Adam Gibson
*
*/
public abstract class BaseS3 {
/**
*
*/
protected static final long serialVersionUID = -2280107690193651289L;
protected String accessKey;
protected String secretKey;
protected AWSCredentials creds;
public final static String ACCESS_KEY = "org.deeplearning4j.aws.accessKey";
public final static String ACCESS_SECRET = "org.deeplearning4j.aws.accessSecret";
public final static String AWS_ACCESS_KEY = "AWS_ACCESS_KEY"; //"AWS_ACCESS_KEY_ID";
public final static String AWS_SECRET_KEY = "AWS_SECRET_KEY"; //"AWS_SECRET_ACCESS_KEY";
protected void findCreds() {
if (System.getProperty(ACCESS_KEY) != null && System.getProperty(ACCESS_SECRET) != null) {
accessKey = System.getProperty(ACCESS_KEY);
secretKey = System.getProperty(ACCESS_SECRET);
}
else if (System.getenv(AWS_ACCESS_KEY) != null && System.getenv(AWS_SECRET_KEY) != null) {
accessKey = System.getenv(AWS_ACCESS_KEY);
secretKey = System.getenv(AWS_SECRET_KEY);
}
}
public BaseS3() {
findCreds();
if (accessKey != null && secretKey != null)
creds = new BasicAWSCredentials(accessKey, secretKey);
if (creds == null)
throw new IllegalStateException("Unable to find ec2 credentials");
}
public BaseS3(File file) throws Exception {
if (accessKey != null && secretKey != null)
creds = new BasicAWSCredentials(accessKey, secretKey);
else
creds = new PropertiesCredentials(file);
}
public BaseS3(InputStream is) throws Exception {
if (accessKey != null && secretKey != null)
creds = new BasicAWSCredentials(accessKey, secretKey);
else
creds = new PropertiesCredentials(is);
}
public AWSCredentials getCreds() {
return creds;
}
public void setCreds(AWSCredentials creds) {
this.creds = creds;
}
public AmazonS3 getClient() {
return new AmazonS3Client(creds);
}
public AmazonEC2 getEc2() {
return new AmazonEC2Client(creds);
}
}

View File

@ -1,83 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.aws.s3.reader;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import java.io.InputStream;
import java.util.Iterator;
import java.util.List;
/**
* baseline data applyTransformToDestination iterator for
* @author Adam Gibson
*
*/
public abstract class BaseS3DataSetIterator implements DataSetIterator {
/**
*
*/
private static final long serialVersionUID = 885205002006822431L;
private S3Downloader downloader;
private List<String> buckets;
private int currBucket;
private Iterator<InputStream> currIterator;
public BaseS3DataSetIterator() {
downloader = new S3Downloader();
buckets = downloader.buckets();
currBucket = 0;
currIterator = downloader.iterateBucket(buckets.get(currBucket));
}
public InputStream nextObject() {
if (currIterator.hasNext())
return currIterator.next();
else if (currBucket < buckets.size()) {
currBucket++;
currIterator = downloader.iterateBucket(buckets.get(currBucket));
return currIterator.next();
}
return null;
}
@Override
public boolean hasNext() {
return currBucket < buckets.size() && currIterator.hasNext();
}
public String currBucket() {
return buckets.get(currBucket);
}
public void nextBucket() {
currBucket++;
}
}

View File

@ -1,93 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.aws.s3.reader;
import com.amazonaws.services.s3.model.ObjectListing;
import com.amazonaws.services.s3.model.S3ObjectSummary;
import java.io.InputStream;
import java.util.Iterator;
import java.util.List;
/**
* Iterator over individual S3 bucket
* @author Adam Gibson
*
*/
public class BucketIterator implements Iterator<InputStream> {
private S3Downloader s3;
private String bucket;
private ObjectListing currList;
private List<S3ObjectSummary> currObjects;
private int currObject;
public BucketIterator(String bucket) {
this(bucket, null);
}
public BucketIterator(String bucket, S3Downloader s3) {
this.bucket = bucket;
if (s3 == null)
this.s3 = new S3Downloader();
else
this.s3 = s3;
currList = this.s3.listObjects(bucket);
currObjects = currList.getObjectSummaries();
}
@Override
public boolean hasNext() {
return currObject < currObjects.size();
}
@Override
public InputStream next() {
if (currObject < currObjects.size()) {
InputStream ret = s3.objectForKey(bucket, currObjects.get(currObject).getKey());
currObject++;
return ret;
} else if (currList.isTruncated()) {
currList = s3.nextList(currList);
currObjects = currList.getObjectSummaries();
currObject = 0;
InputStream ret = s3.objectForKey(bucket, currObjects.get(currObject).getKey());
currObject++;
return ret;
}
throw new IllegalStateException("Indeterminate state");
}
@Override
public void remove() {
throw new UnsupportedOperationException();
}
}

View File

@ -1,38 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.aws.s3.reader;
import com.amazonaws.services.s3.AmazonS3;
/**
* When paginating through a result applyTransformToDestination,
* allows the user to react to a bucket result being found
* @author Adam Gibson
*
*/
public interface BucketKeyListener {
/**
*
* @param s3 an s3 client
* @param bucket the bucket being iterated on
* @param key the current key
*/
void onKey(AmazonS3 s3, String bucket, String key);
}

View File

@ -1,178 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.aws.s3.reader;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.model.*;
import com.amazonaws.services.s3.transfer.MultipleFileDownload;
import com.amazonaws.services.s3.transfer.TransferManager;
import org.apache.commons.io.IOUtils;
import org.deeplearning4j.aws.s3.BaseS3;
import java.io.*;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
/**
* Downloads files from S3
* @author Adam Gibson
*
*/
public class S3Downloader extends BaseS3 {
/**
* Return the keys for a bucket
* @param bucket the bucket to get the keys for
* @return the bucket's keys
*/
public List<String> keysForBucket(String bucket) {
AmazonS3 s3 = getClient();
List<String> ret = new ArrayList<>();
ListObjectsRequest listObjectsRequest = new ListObjectsRequest().withBucketName(bucket);
ObjectListing objectListing;
do {
objectListing = s3.listObjects(listObjectsRequest);
for (S3ObjectSummary objectSummary : objectListing.getObjectSummaries()) {
ret.add(objectSummary.getKey());
}
listObjectsRequest.setMarker(objectListing.getNextMarker());
} while (objectListing.isTruncated());
return ret;
}
/**
* Returns the list of buckets in s3
* @return the list of buckets
*/
public List<String> buckets() {
List<String> ret = new ArrayList<>();
AmazonS3 s3 = getClient();
List<Bucket> buckets = s3.listBuckets();
for (Bucket b : buckets)
ret.add(b.getName());
return ret;
}
/**
* Iterate over individual buckets.
* Returns input streams to each object.
* It is your responsibility to close the input streams
* @param bucket the bucket to iterate over
* @return an iterator over the objects in an s3 bucket
*/
public Iterator<InputStream> iterateBucket(String bucket) {
return new BucketIterator(bucket, this);
}
/**
* Iterator style one list at a time
* @param list the list to get the next batch for
* @return the next batch of objects or null if
* none are left
*/
public ObjectListing nextList(ObjectListing list) {
AmazonS3 s3 = getClient();
if (list.isTruncated())
return s3.listNextBatchOfObjects(list);
return null;
}
/**
* Simple way of retrieving the listings for a bucket
* @param bucket the bucket to retrieve listings for
* @return the object listing for this bucket
*/
public ObjectListing listObjects(String bucket) {
AmazonS3 s3 = getClient();
ObjectListing list = s3.listObjects(bucket);
return list;
}
/**
* Paginates through a bucket's keys invoking the listener
* at each key
* @param bucket the bucket to iterate
* @param listener the listener
*/
public void paginate(String bucket, BucketKeyListener listener) {
AmazonS3 s3 = getClient();
ObjectListing list = s3.listObjects(bucket);
for (S3ObjectSummary summary : list.getObjectSummaries()) {
if (listener != null)
listener.onKey(s3, bucket, summary.getKey());
}
while (list.isTruncated()) {
list = s3.listNextBatchOfObjects(list);
for (S3ObjectSummary summary : list.getObjectSummaries()) {
if (listener != null)
listener.onKey(s3, bucket, summary.getKey());
}
}
}
/**
* Returns an input stream for the given bucket and key
* @param bucket the bucket to retrieve from
* @param key the key of the objec t
* @return an input stream to the object
*/
public InputStream objectForKey(String bucket, String key) {
AmazonS3 s3 = getClient();
S3Object obj = s3.getObject(bucket, key);
InputStream is = obj.getObjectContent();
return is;
}
public void download(String bucket, String key, File to) throws IOException {
AmazonS3 s3 = getClient();
S3Object obj = s3.getObject(bucket, key);
InputStream is = obj.getObjectContent();
BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(to));
IOUtils.copy(is, bos);
bos.close();
is.close();
obj.close();
}
public void download(String bucket, String key, OutputStream to) throws IOException {
AmazonS3 s3 = getClient();
S3Object obj = s3.getObject(bucket, key);
InputStream is = obj.getObjectContent();
BufferedOutputStream bos = new BufferedOutputStream(to);
IOUtils.copy(is, bos);
bos.close();
is.close();
obj.close();
}
public MultipleFileDownload downloadFolder(String bucketName, String keyPrefix, File folderPath) {
TransferManager transfer = new TransferManager(getClient());
return transfer.downloadDirectory(bucketName, keyPrefix, folderPath);
}
}

View File

@ -1,172 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.aws.s3.uploader;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.AmazonS3Client;
import com.amazonaws.services.s3.model.*;
import com.amazonaws.services.s3.transfer.MultipleFileUpload;
import com.amazonaws.services.s3.transfer.TransferManager;
import org.deeplearning4j.aws.s3.BaseS3;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
/**
* Uploads files to S3
*
* @see {@link BaseS3}
* @author Adam Gibson
*
*/
public class S3Uploader extends BaseS3 {
/**
* Multi part upload for big files
* @param file the file to upload
* @param bucketName the bucket name to upload
*/
public void multiPartUpload(File file, String bucketName) {
AmazonS3 client = new AmazonS3Client(creds);
bucketName = ensureValidBucketName(bucketName);
List<Bucket> buckets = client.listBuckets();
for (Bucket b : buckets)
if (b.getName().equals(bucketName)) {
doMultiPart(client, bucketName, file);
return;
}
//bucket didn't exist: create it
client.createBucket(bucketName);
doMultiPart(client, bucketName, file);
}
/**
* Upload the file to the bucket.
* Will create the bucket if it hasn't already been created
* @param file the file to upload
* @param bucketName the name of the bucket
*/
public void upload(File file, String bucketName) {
AmazonS3 client = new AmazonS3Client(creds);
bucketName = ensureValidBucketName(bucketName);
List<Bucket> buckets = client.listBuckets();
for (Bucket b : buckets)
if (b.getName().equals(bucketName)) {
client.putObject(bucketName, file.getName(), file);
return;
}
//bucket didn't exist: create it
client.createBucket(bucketName);
client.putObject(bucketName, file.getName(), file);
}
private void doMultiPart(AmazonS3 s3Client, String bucketName, File file) {
// Create a list of UploadPartResponse objects. You get one of these
// for each part upload.
List<PartETag> partETags = new ArrayList<>();
// Step 1: Initialize.
InitiateMultipartUploadRequest initRequest = new InitiateMultipartUploadRequest(bucketName, file.getName());
InitiateMultipartUploadResult initResponse = s3Client.initiateMultipartUpload(initRequest);
long contentLength = file.length();
long partSize = 5242880; // Set part size to 5 MB.
try {
// Step 2: Upload parts.
long filePosition = 0;
for (int i = 1; filePosition < contentLength; i++) {
// Last part can be less than 5 MB. Adjust part size.
partSize = Math.min(partSize, (contentLength - filePosition));
// Create request to upload a part.
UploadPartRequest uploadRequest = new UploadPartRequest().withBucketName(bucketName)
.withKey(file.getName()).withUploadId(initResponse.getUploadId()).withPartNumber(i)
.withFileOffset(filePosition).withFile(file).withPartSize(partSize);
// Upload part and add response to our list.
partETags.add(s3Client.uploadPart(uploadRequest).getPartETag());
filePosition += partSize;
}
// Step 3: Complete.
CompleteMultipartUploadRequest compRequest = new CompleteMultipartUploadRequest(bucketName, file.getName(),
initResponse.getUploadId(), partETags);
s3Client.completeMultipartUpload(compRequest);
} catch (Exception e) {
s3Client.abortMultipartUpload(
new AbortMultipartUploadRequest(bucketName, file.getName(), initResponse.getUploadId()));
}
}
private String ensureValidBucketName(String bucketName) {
String formatted = bucketName.replaceAll("\\s+", "_");
int length = bucketName.length();
if (length >= 62)
length = 62;
formatted = formatted.substring(0, length);
formatted = formatted.replace(".", "d");
formatted = formatted.toLowerCase();
if (formatted.endsWith("-"))
formatted = formatted.substring(0, length - 1);
return formatted;
}
public void upload(File file, String name, String bucketName) {
AmazonS3 client = getClient();
bucketName = ensureValidBucketName(bucketName);
List<Bucket> buckets = client.listBuckets();
// ObjectMetadata med = new ObjectMetadata();
// med.setContentLength(fileLength);
for (Bucket b : buckets)
if (b.getName().equals(bucketName)) {
//client.putObject(bucketName, name, is, med);
client.putObject(new PutObjectRequest(bucketName, name, file));
return;
}
//bucket didn't exist: createComplex it
client.createBucket(bucketName);
//client.putObject(bucketName, name, is, med);
client.putObject(new PutObjectRequest(bucketName, name, file));
}
public MultipleFileUpload uploadFolder(String bucketName, String keyPrefix, File folderPath,
boolean includeSubDir) {
TransferManager transfer = new TransferManager(getClient());
return transfer.uploadDirectory(bucketName, keyPrefix, folderPath, includeSubDir);
}
public MultipleFileUpload uploadFileList(String bucketName, File folderPath, List<File> fileList,
String keyPrefix) {
TransferManager transfer = new TransferManager(getClient());
return transfer.uploadFileList(bucketName, keyPrefix, folderPath, fileList);
}
}

View File

@ -1,45 +0,0 @@
#!/bin/bash
################################################################################
# Copyright (c) 2015-2018 Skymind, Inc.
#
# This program and the accompanying materials are made available under the
# terms of the Apache License, Version 2.0 which is available at
# https://www.apache.org/licenses/LICENSE-2.0.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# SPDX-License-Identifier: Apache-2.0
################################################################################
sudo yum -y install blas java-1.7.0-openjdk-devel
if [ ! -f "/opt/dl4j" ];
then
sudo mkdir /opt/dl4j
sudo yum -y install git
git clone https://github.com/agibsonccc/java-deeplearning
wget http://www.trieuvan.com/apache/maven/maven-3/3.1.1/binaries/apache-maven-3.1.1-bin.tar.gz
sudo mv apache-maven-3.1.1-bin.tar.gz /opt
cd /opt && sudo tar xvf apache-maven-3.1.1-bin.tar.gz && sudo mv apache-maven-3.1.1 /opt/mvn
cd /home/ec2-user/java-deeplearning/ && /opt/mvn/bin/mvn -DskipTests clean install
echo "Printing distribution"
ls /home/ec2-user/java-deeplearning/deeplearning4j-distribution/target
echo "Before moving distribution"
sudo mv deeplearning4j-distribution/target/deeplearning4j-dist.tar.gz /opt
echo "Moving distribution to opt directory..."
echo "Moving in to opt directory"
cd /opt
sudo tar xzvf deeplearning4j-dist.tar.gz
#sudo mkdir /opt/dl4j
echo "Moving jars in to /opt/dl4j/..."
sudo mv *.jar /opt/dl4j
fi

View File

@ -28,7 +28,6 @@
<name>DeepLearning4j-scaleout-parent</name>
<modules>
<module>deeplearning4j-aws</module>
<module>spark</module>
<module>deeplearning4j-scaleout-parallelwrapper</module>
<module>deeplearning4j-scaleout-parallelwrapper-parameter-server</module>

View File

@ -1,58 +0,0 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>deeplearning4j-parent</artifactId>
<groupId>org.deeplearning4j</groupId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>deeplearning4j-util</artifactId>
<packaging>jar</packaging>
<name>deeplearning4j-util</name>
<url>http://maven.apache.org</url>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
</properties>
<dependencies>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-api</artifactId>
<version>${nd4j.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common</artifactId>
<version>${nd4j.version}</version>
</dependency>
</dependencies>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.2</id>
</profile>
</profiles>
</project>

View File

@ -1,193 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.util;
import org.apache.commons.io.FileUtils;
import org.nd4j.linalg.util.SerializationUtils;
import java.io.File;
import java.io.IOException;
import java.io.Serializable;
import java.util.Collection;
import java.util.Iterator;
import java.util.Queue;
import java.util.UUID;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
/**
* Naive disk based queue for storing items on disk.
* Only meant for poll and adding items.
* @author Adam Gibson
*/
public class DiskBasedQueue<E> implements Queue<E>, Serializable {
private File dir;
private Queue<String> paths = new ConcurrentLinkedDeque<>();
private AtomicBoolean running = new AtomicBoolean(true);
private Queue<E> save = new ConcurrentLinkedDeque<>();
public DiskBasedQueue() {
this(".queue");
}
public DiskBasedQueue(String path) {
this(new File(path));
}
public DiskBasedQueue(File dir) {
this.dir = dir;
if (!dir.exists() && dir.isDirectory()) {
throw new IllegalArgumentException("Illegal queue: must be a directory");
}
if (!dir.exists())
dir.mkdirs();
if (dir.listFiles() != null && dir.listFiles().length > 1)
try {
FileUtils.deleteDirectory(dir);
} catch (IOException e) {
e.printStackTrace();
}
dir.mkdir();
Thread t = Executors.defaultThreadFactory().newThread(new Runnable() {
@Override
public void run() {
while (running.get()) {
while (!save.isEmpty())
addAndSave(save.poll());
ThreadUtils.uncheckedSleep(1000);
}
}
});
t.setName("DiskBasedQueueSaver");
t.setDaemon(true);
t.start();
}
@Override
public int size() {
return paths.size();
}
@Override
public boolean isEmpty() {
return paths.isEmpty();
}
@Override
public boolean contains(Object o) {
throw new UnsupportedOperationException();
}
@Override
public Iterator<E> iterator() {
throw new UnsupportedOperationException();
}
@Override
public Object[] toArray() {
throw new UnsupportedOperationException();
}
@Override
public <T> T[] toArray(T[] a) {
throw new UnsupportedOperationException();
}
@Override
public boolean add(E e) {
save.add(e);
return true;
}
@Override
public boolean remove(Object o) {
throw new UnsupportedOperationException();
}
@Override
public boolean containsAll(Collection<?> c) {
throw new UnsupportedOperationException();
}
@Override
public boolean addAll(Collection<? extends E> c) {
for (E e : c)
addAndSave(e);
return true;
}
@Override
public boolean removeAll(Collection<?> c) {
throw new UnsupportedOperationException();
}
@Override
public boolean retainAll(Collection<?> c) {
throw new UnsupportedOperationException();
}
@Override
public void clear() {
throw new UnsupportedOperationException();
}
@Override
public boolean offer(E e) {
throw new UnsupportedOperationException();
}
@Override
public E remove() {
throw new UnsupportedOperationException();
}
@Override
public E poll() {
String path = paths.poll();
E ret = SerializationUtils.readObject(new File(path));
File item = new File(path);
item.delete();
return ret;
}
@Override
public E element() {
throw new UnsupportedOperationException();
}
@Override
public E peek() {
throw new UnsupportedOperationException();
}
private void addAndSave(E e) {
File path = new File(dir, UUID.randomUUID().toString());
SerializationUtils.saveObject(e, path);
paths.add(path.getAbsolutePath());
}
}

View File

@ -1,132 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.util;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Properties;
/**
* @author Adam Gibson
*/
public class Dl4jReflection {
private Dl4jReflection() {}
/**
* Gets the empty constructor from a class
* @param clazz the class to get the constructor from
* @return the empty constructor for the class
*/
public static Constructor<?> getEmptyConstructor(Class<?> clazz) {
Constructor<?> c = clazz.getDeclaredConstructors()[0];
for (int i = 0; i < clazz.getDeclaredConstructors().length; i++) {
if (clazz.getDeclaredConstructors()[i].getParameterTypes().length < 1) {
c = clazz.getDeclaredConstructors()[i];
break;
}
}
return c;
}
public static Field[] getAllFields(Class<?> clazz) {
// Keep backing up the inheritance hierarchy.
Class<?> targetClass = clazz;
List<Field> fields = new ArrayList<>();
do {
fields.addAll(Arrays.asList(targetClass.getDeclaredFields()));
targetClass = targetClass.getSuperclass();
} while (targetClass != null && targetClass != Object.class);
return fields.toArray(new Field[fields.size()]);
}
/**
* Sets the properties of the given object
* @param obj the object o set
* @param props the properties to set
*/
public static void setProperties(Object obj, Properties props) throws Exception {
for (Field field : obj.getClass().getDeclaredFields()) {
field.setAccessible(true);
if (props.containsKey(field.getName())) {
set(field, obj, props.getProperty(field.getName()));
}
}
}
/* sets a field with a fairly basic strategy */
private static void set(Field field, Object obj, String value) throws Exception {
Class<?> clazz = field.getType();
field.setAccessible(true);
if (clazz.equals(Double.class) || clazz.equals(double.class)) {
double val = Double.valueOf(value);
field.set(obj, val);
} else if (clazz.equals(String.class)) {
field.set(obj, value);
} else if (clazz.equals(Integer.class) || clazz.equals(int.class)) {
int val = Integer.parseInt(value);
field.set(obj, val);
} else if (clazz.equals(Float.class) || clazz.equals(float.class)) {
float f = Float.parseFloat(value);
field.set(obj, f);
}
}
/**
* Get fields as properties
* @param obj the object to get fields for
* @param clazzes the classes to use for reflection and properties.
* T
* @return the fields as properties
*/
public static Properties getFieldsAsProperties(Object obj, Class<?>[] clazzes) throws Exception {
Properties props = new Properties();
for (Field field : obj.getClass().getDeclaredFields()) {
if (Modifier.isStatic(field.getModifiers()))
continue;
field.setAccessible(true);
Class<?> type = field.getType();
if (clazzes == null || contains(type, clazzes)) {
Object val = field.get(obj);
if (val != null)
props.put(field.getName(), val.toString());
}
}
return props;
}
private static boolean contains(Class<?> test, Class<?>[] arr) {
for (Class<?> c : arr)
if (c.equals(test))
return true;
return false;
}
}

View File

@ -1,135 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<artifactId>deeplearning4j-parent</artifactId>
<groupId>org.deeplearning4j</groupId>
<version>1.0.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>dl4j-perf</artifactId>
<name>dl4j-perf</name>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<maven.compiler.source>1.7</maven.compiler.source>
<maven.compiler.target>1.7</maven.compiler.target>
</properties>
<dependencies>
<dependency>
<groupId>org.slf4j</groupId>
<artifactId>slf4j-api</artifactId>
</dependency>
<dependency>
<groupId>com.github.oshi</groupId>
<artifactId>oshi-json</artifactId>
<version>${oshi.version}</version>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-nn</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.github.oshi</groupId>
<artifactId>oshi-core</artifactId>
<version>${oshi.version}</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>${lombok.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>ch.qos.logback</groupId>
<artifactId>logback-classic</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-datasets</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.deeplearning4j</groupId>
<artifactId>deeplearning4j-common-tests</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<pluginManagement><!-- lock down plugins versions to avoid using Maven defaults (may be moved to parent pom) -->
<plugins>
<plugin>
<artifactId>maven-clean-plugin</artifactId>
<version>3.0.0</version>
</plugin>
<!-- see http://maven.apache.org/ref/current/maven-core/default-bindings.html#Plugin_bindings_for_jar_packaging -->
<plugin>
<artifactId>maven-resources-plugin</artifactId>
<version>3.0.2</version>
</plugin>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.7.0</version>
</plugin>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.20.1</version>
</plugin>
<plugin>
<artifactId>maven-jar-plugin</artifactId>
<version>3.0.2</version>
</plugin>
<plugin>
<artifactId>maven-install-plugin</artifactId>
<version>2.5.2</version>
</plugin>
<plugin>
<artifactId>maven-deploy-plugin</artifactId>
<version>2.8.2</version>
</plugin>
</plugins>
</pluginManagement>
</build>
<profiles>
<profile>
<id>test-nd4j-native</id>
</profile>
<profile>
<id>test-nd4j-cuda-10.2</id>
</profile>
</profiles>
</project>

View File

@ -1,50 +0,0 @@
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
~ Copyright (c) 2015-2018 Skymind, Inc.
~
~ This program and the accompanying materials are made available under the
~ terms of the Apache License, Version 2.0 which is available at
~ https://www.apache.org/licenses/LICENSE-2.0.
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
~ License for the specific language governing permissions and limitations
~ under the License.
~
~ SPDX-License-Identifier: Apache-2.0
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
<configuration>
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
<file>logs/application.log</file>
<encoder>
<pattern>%date - [%level] - from %logger in %thread
%n%message%n%xException%n</pattern>
</encoder>
</appender>
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
<encoder>
<pattern> %logger{15} - %message%n%xException{5}
</pattern>
</encoder>
</appender>
<logger name="org.apache.catalina.core" level="DEBUG" />
<logger name="org.springframework" level="DEBUG" />
<logger name="org.deeplearning4j" level="INFO" />
<logger name="org.datavec" level="INFO" />
<logger name="org.nd4j" level="INFO" />
<logger name="opennlp.uima.util" level="OFF" />
<logger name="org.apache.uima" level="OFF" />
<logger name="org.cleartk" level="OFF" />
<root level="ERROR">
<appender-ref ref="STDOUT" />
<appender-ref ref="FILE" />
</root>
</configuration>

View File

@ -140,8 +140,6 @@
<module>deeplearning4j-nearestneighbors-parent</module>
<module>deeplearning4j-data</module>
<module>deeplearning4j-manifold</module>
<module>deeplearning4j-util</module>
<module>dl4j-perf</module>
<module>dl4j-integration-tests</module>
<module>deeplearning4j-common</module>
<module>deeplearning4j-remote</module>

View File

@ -39,6 +39,7 @@ do
done
CHIP="${CHIP:-cpu}"
export GTEST_OUTPUT="xml:../target/surefire-reports/TEST-${CHIP}-results.xml"
# On Mac, make sure it can find libraries for GCC
export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/lib/gcc/6/:/usr/local/lib/gcc/5/
@ -47,8 +48,9 @@ export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/
if [ -n "$BUILD_PATH" ]; then
if which cygpath; then
BUILD_PATH=$(cygpath -p $BUILD_PATH)
export GTEST_OUTPUT="xml:'..\target\surefire-reports\TEST-${CHIP}-results.xml'"
fi
export PATH="$PATH:$BUILD_PATH"
fi
../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests --gtest_output="xml:../target/surefire-reports/TEST-${CHIP}-results.xml"
../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests

View File

@ -217,16 +217,6 @@
<artifactId>commons-net</artifactId>
<version>${commons-net.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-buffer</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-context</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>net.ericaro</groupId>
<artifactId>neoitertools</artifactId>
@ -238,6 +228,16 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.nd4j</groupId>
<artifactId>nd4j-common</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.bytedeco</groupId>
<artifactId>javacpp</artifactId>
<version>${javacpp.version}</version>
</dependency>
</dependencies>
<profiles>

View File

@ -81,11 +81,6 @@ public class BroadcastMax extends BaseBroadcastOp {
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
}
@Override
public String tensorflowName() {
return "Max";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return null;

View File

@ -81,11 +81,6 @@ public class BroadcastMin extends BaseBroadcastOp {
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
}
@Override
public String tensorflowName() {
return "Min";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return null;

View File

@ -75,11 +75,6 @@ public class BroadcastMulOp extends BaseBroadcastOp {
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
}
@Override
public String tensorflowName() {
return "Mul";
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return null;

View File

@ -75,11 +75,6 @@ public class BroadcastSubOp extends BaseBroadcastOp {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName(){
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
return null;

Some files were not shown because too many files have changed in this diff Show More