Cleanup of multiple projects (#175)
* Cleanup modules * Moving subprojects to nd4j-api * Project cleanup * Dropped AWS sub-project * dl4j-util moved to core * dl4j-perf moved to core * Tests coverage * Revert "Moving subprojects to nd4j-api" This reverts commit bc6eb573c6b60c407ade47172c5d204725077e6b. * Moved nd4j-buffer and nd4j-context to nd4j-api * Rolled back change * Revert "Project cleanup" This reverts commit 64ac7f369b2d968f7be437718034f093fc886ffc. * Datavec cleaned up * Revert "Moved nd4j-buffer and nd4j-context to nd4j-api" This reverts commit 75f4e8da80d2551e44e1251dd6c5923289fff8e1. # Conflicts: # nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionBpOpValidation.java * Resolve conflict * Compilation fixed. * nd4j-context and nd4j-buffer moved to nd4j-api * Fixed TF mapping for mmul * Fix for dl4j-cuda tests Signed-off-by: Alex Black <blacka101@gmail.com> * Move last few tests from deeplearning4j-nn to -core Signed-off-by: Alex Black <blacka101@gmail.com> * Remove incorrect TF import mapping for TensorMmul op Signed-off-by: Alex Black <blacka101@gmail.com> * Cleaned TF mapping * Fix path for test results on windows * Remove old dependency Signed-off-by: Alex Black <blacka101@gmail.com> * One more attempt to fix path for test results on windows * fixup! One more attempt to fix path for test results on windows * fixup! One more attempt to fix path for test results on windows Co-authored-by: Alex Black <blacka101@gmail.com> Co-authored-by: Serhii Shepel <9946053+sshepel@users.noreply.github.com> Co-authored-by: raver119 <raver119@gmail.com>master
parent
99a54829c2
commit
4db28a9300
|
@ -1,116 +0,0 @@
|
||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
~
|
|
||||||
~ This program and the accompanying materials are made available under the
|
|
||||||
~ terms of the Apache License, Version 2.0 which is available at
|
|
||||||
~ https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
~
|
|
||||||
~ Unless required by applicable law or agreed to in writing, software
|
|
||||||
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
~ License for the specific language governing permissions and limitations
|
|
||||||
~ under the License.
|
|
||||||
~
|
|
||||||
~ SPDX-License-Identifier: Apache-2.0
|
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
|
||||||
|
|
||||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
|
|
||||||
<parent>
|
|
||||||
<artifactId>datavec-parent</artifactId>
|
|
||||||
<groupId>org.datavec</groupId>
|
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
|
||||||
</parent>
|
|
||||||
|
|
||||||
<modelVersion>4.0.0</modelVersion>
|
|
||||||
|
|
||||||
<artifactId>datavec-camel</artifactId>
|
|
||||||
|
|
||||||
<name>DataVec Camel Component</name>
|
|
||||||
<url>http://deeplearning4j.org</url>
|
|
||||||
|
|
||||||
<dependencies>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.camel</groupId>
|
|
||||||
<artifactId>camel-csv</artifactId>
|
|
||||||
<version>${camel.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.datavec</groupId>
|
|
||||||
<artifactId>datavec-api</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.camel</groupId>
|
|
||||||
<artifactId>camel-core</artifactId>
|
|
||||||
<version>${camel.version}</version>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<!-- support camel documentation -->
|
|
||||||
<!-- logging -->
|
|
||||||
<!-- testing -->
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.camel</groupId>
|
|
||||||
<artifactId>camel-test</artifactId>
|
|
||||||
<version>${camel.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
|
|
||||||
<build>
|
|
||||||
<defaultGoal>install</defaultGoal>
|
|
||||||
|
|
||||||
<plugins>
|
|
||||||
|
|
||||||
<plugin>
|
|
||||||
<artifactId>maven-compiler-plugin</artifactId>
|
|
||||||
<configuration>
|
|
||||||
<source>1.7</source>
|
|
||||||
<target>1.7</target>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
|
|
||||||
<plugin>
|
|
||||||
<artifactId>maven-resources-plugin</artifactId>
|
|
||||||
<version>3.0.1</version>
|
|
||||||
<configuration>
|
|
||||||
<encoding>UTF-8</encoding>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
|
|
||||||
<!-- generate components meta-data and validate component includes documentation etc -->
|
|
||||||
<plugin>
|
|
||||||
<groupId>org.apache.camel</groupId>
|
|
||||||
<artifactId>camel-package-maven-plugin</artifactId>
|
|
||||||
<version>${camel.version}</version>
|
|
||||||
<executions>
|
|
||||||
<execution>
|
|
||||||
<id>prepare</id>
|
|
||||||
<goals>
|
|
||||||
<goal>prepare-components</goal>
|
|
||||||
</goals>
|
|
||||||
<phase>generate-resources</phase>
|
|
||||||
</execution>
|
|
||||||
<!-- <execution>
|
|
||||||
<id>validate</id>
|
|
||||||
<goals>
|
|
||||||
<goal>validate-components</goal>
|
|
||||||
</goals>
|
|
||||||
<phase>prepare-package</phase>
|
|
||||||
</execution>-->
|
|
||||||
</executions>
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</build>
|
|
||||||
|
|
||||||
<profiles>
|
|
||||||
<profile>
|
|
||||||
<id>test-nd4j-native</id>
|
|
||||||
</profile>
|
|
||||||
<profile>
|
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
|
||||||
</profile>
|
|
||||||
</profiles>
|
|
||||||
</project>
|
|
|
@ -1,45 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.datavec.camel.component;
|
|
||||||
|
|
||||||
import org.apache.camel.CamelContext;
|
|
||||||
import org.apache.camel.Endpoint;
|
|
||||||
import org.apache.camel.impl.UriEndpointComponent;
|
|
||||||
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents the component that manages {@link DataVecEndpoint}.
|
|
||||||
*/
|
|
||||||
public class DataVecComponent extends UriEndpointComponent {
|
|
||||||
|
|
||||||
public DataVecComponent() {
|
|
||||||
super(DataVecEndpoint.class);
|
|
||||||
}
|
|
||||||
|
|
||||||
public DataVecComponent(CamelContext context) {
|
|
||||||
super(context, DataVecEndpoint.class);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected Endpoint createEndpoint(String uri, String remaining, Map<String, Object> parameters) throws Exception {
|
|
||||||
DataVecEndpoint endpoint = new DataVecEndpoint(uri, this);
|
|
||||||
setProperties(endpoint, parameters);
|
|
||||||
endpoint.setInputFormat(remaining);
|
|
||||||
return endpoint;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,93 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.datavec.camel.component;
|
|
||||||
|
|
||||||
|
|
||||||
import org.apache.camel.Exchange;
|
|
||||||
import org.apache.camel.Processor;
|
|
||||||
import org.apache.camel.impl.ScheduledPollConsumer;
|
|
||||||
import org.datavec.api.conf.Configuration;
|
|
||||||
import org.datavec.api.formats.input.InputFormat;
|
|
||||||
import org.datavec.api.records.reader.RecordReader;
|
|
||||||
import org.datavec.api.split.InputSplit;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The DataVec consumer.
|
|
||||||
* @author Adam Gibson
|
|
||||||
*/
|
|
||||||
public class DataVecConsumer extends ScheduledPollConsumer {
|
|
||||||
private final DataVecEndpoint endpoint;
|
|
||||||
private Class<? extends InputFormat> inputFormatClazz;
|
|
||||||
private Class<? extends DataVecMarshaller> marshallerClazz;
|
|
||||||
private InputFormat inputFormat;
|
|
||||||
private Configuration configuration;
|
|
||||||
private DataVecMarshaller marshaller;
|
|
||||||
|
|
||||||
|
|
||||||
public DataVecConsumer(DataVecEndpoint endpoint, Processor processor) {
|
|
||||||
super(endpoint, processor);
|
|
||||||
this.endpoint = endpoint;
|
|
||||||
|
|
||||||
try {
|
|
||||||
inputFormatClazz = (Class<? extends InputFormat>) Class.forName(endpoint.getInputFormat());
|
|
||||||
inputFormat = inputFormatClazz.newInstance();
|
|
||||||
marshallerClazz = (Class<? extends DataVecMarshaller>) Class.forName(endpoint.getInputMarshaller());
|
|
||||||
marshaller = marshallerClazz.newInstance();
|
|
||||||
configuration = new Configuration();
|
|
||||||
for (String prop : endpoint.getConsumerProperties().keySet())
|
|
||||||
configuration.set(prop, endpoint.getConsumerProperties().get(prop).toString());
|
|
||||||
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//stub, still need to fill out more of the end point yet..endpoint will likely be initialized with a split
|
|
||||||
protected InputSplit inputFromExchange(Exchange exchange) {
|
|
||||||
return marshaller.getSplit(exchange);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected int poll() throws Exception {
|
|
||||||
Exchange exchange = endpoint.createExchange();
|
|
||||||
InputSplit split = inputFromExchange(exchange);
|
|
||||||
RecordReader reader = inputFormat.createReader(split, configuration);
|
|
||||||
int numMessagesPolled = 0;
|
|
||||||
while (reader.hasNext()) {
|
|
||||||
// create a message body
|
|
||||||
while (reader.hasNext()) {
|
|
||||||
exchange.getIn().setBody(reader.next());
|
|
||||||
|
|
||||||
try {
|
|
||||||
// send message to next processor in the route
|
|
||||||
getProcessor().process(exchange);
|
|
||||||
numMessagesPolled++; // number of messages polled
|
|
||||||
} finally {
|
|
||||||
// log exception if an exception occurred and was not handled
|
|
||||||
if (exchange.getException() != null) {
|
|
||||||
getExceptionHandler().handleException("Error processing exchange", exchange,
|
|
||||||
exchange.getException());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return numMessagesPolled;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,68 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.datavec.camel.component;
|
|
||||||
|
|
||||||
import lombok.Data;
|
|
||||||
import org.apache.camel.Consumer;
|
|
||||||
import org.apache.camel.Processor;
|
|
||||||
import org.apache.camel.Producer;
|
|
||||||
import org.apache.camel.impl.DefaultEndpoint;
|
|
||||||
import org.apache.camel.spi.Metadata;
|
|
||||||
import org.apache.camel.spi.UriEndpoint;
|
|
||||||
import org.apache.camel.spi.UriParam;
|
|
||||||
import org.apache.camel.spi.UriPath;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Represents a DataVec endpoint.
|
|
||||||
* @author Adam Gibson
|
|
||||||
*/
|
|
||||||
@UriEndpoint(scheme = "datavec", title = "datavec", syntax = "datavec:inputFormat/?outputFormat=?&inputMarshaller=?",
|
|
||||||
consumerClass = DataVecConsumer.class, label = "datavec")
|
|
||||||
@Data
|
|
||||||
public class DataVecEndpoint extends DefaultEndpoint {
|
|
||||||
@UriPath
|
|
||||||
@Metadata(required = "true")
|
|
||||||
private String inputFormat;
|
|
||||||
@UriParam(defaultValue = "")
|
|
||||||
private String outputFormat;
|
|
||||||
@UriParam
|
|
||||||
@Metadata(required = "true")
|
|
||||||
private String inputMarshaller;
|
|
||||||
@UriParam(defaultValue = "org.datavec.api.io.converters.SelfWritableConverter")
|
|
||||||
private String writableConverter;
|
|
||||||
|
|
||||||
public DataVecEndpoint(String uri, DataVecComponent component) {
|
|
||||||
super(uri, component);
|
|
||||||
}
|
|
||||||
|
|
||||||
public DataVecEndpoint(String endpointUri) {
|
|
||||||
super(endpointUri);
|
|
||||||
}
|
|
||||||
|
|
||||||
public Producer createProducer() throws Exception {
|
|
||||||
return new DataVecProducer(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
public Consumer createConsumer(Processor processor) throws Exception {
|
|
||||||
return new DataVecConsumer(this, processor);
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean isSingleton() {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,36 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.datavec.camel.component;
|
|
||||||
|
|
||||||
import org.apache.camel.Exchange;
|
|
||||||
import org.datavec.api.split.InputSplit;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Marshals na exchange in to an input split
|
|
||||||
* @author Adam Gibson
|
|
||||||
*/
|
|
||||||
public interface DataVecMarshaller {
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param exchange
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
InputSplit getSplit(Exchange exchange);
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,109 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.datavec.camel.component;
|
|
||||||
|
|
||||||
import org.apache.camel.Exchange;
|
|
||||||
import org.apache.camel.impl.DefaultProducer;
|
|
||||||
import org.datavec.api.conf.Configuration;
|
|
||||||
import org.datavec.api.formats.input.InputFormat;
|
|
||||||
import org.datavec.api.io.WritableConverter;
|
|
||||||
import org.datavec.api.io.converters.SelfWritableConverter;
|
|
||||||
import org.datavec.api.records.reader.RecordReader;
|
|
||||||
import org.datavec.api.split.InputSplit;
|
|
||||||
import org.datavec.api.writable.Writable;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collection;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The DataVec producer.
|
|
||||||
* Converts input records in to their final form
|
|
||||||
* based on the input split generated from
|
|
||||||
* the given exchange.
|
|
||||||
*
|
|
||||||
* @author Adam Gibson
|
|
||||||
*/
|
|
||||||
public class DataVecProducer extends DefaultProducer {
|
|
||||||
private Class<? extends InputFormat> inputFormatClazz;
|
|
||||||
private Class<? extends DataVecMarshaller> marshallerClazz;
|
|
||||||
private InputFormat inputFormat;
|
|
||||||
private Configuration configuration;
|
|
||||||
private WritableConverter writableConverter;
|
|
||||||
private DataVecMarshaller marshaller;
|
|
||||||
|
|
||||||
|
|
||||||
public DataVecProducer(DataVecEndpoint endpoint) {
|
|
||||||
super(endpoint);
|
|
||||||
if (endpoint.getInputFormat() != null) {
|
|
||||||
try {
|
|
||||||
inputFormatClazz = (Class<? extends InputFormat>) Class.forName(endpoint.getInputFormat());
|
|
||||||
inputFormat = inputFormatClazz.newInstance();
|
|
||||||
marshallerClazz = (Class<? extends DataVecMarshaller>) Class.forName(endpoint.getInputMarshaller());
|
|
||||||
Class<? extends WritableConverter> converterClazz =
|
|
||||||
(Class<? extends WritableConverter>) Class.forName(endpoint.getWritableConverter());
|
|
||||||
writableConverter = converterClazz.newInstance();
|
|
||||||
marshaller = marshallerClazz.newInstance();
|
|
||||||
configuration = new Configuration();
|
|
||||||
for (String prop : endpoint.getConsumerProperties().keySet())
|
|
||||||
configuration.set(prop, endpoint.getConsumerProperties().get(prop).toString());
|
|
||||||
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
//stub, still need to fill out more of the end point yet..endpoint will likely be initialized with a split
|
|
||||||
protected InputSplit inputFromExchange(Exchange exchange) {
|
|
||||||
return marshaller.getSplit(exchange);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void process(Exchange exchange) throws Exception {
|
|
||||||
InputSplit split = inputFromExchange(exchange);
|
|
||||||
RecordReader reader = inputFormat.createReader(split, configuration);
|
|
||||||
Collection<Collection<Writable>> newRecord = new ArrayList<>();
|
|
||||||
if (!(writableConverter instanceof SelfWritableConverter)) {
|
|
||||||
newRecord = new ArrayList<>();
|
|
||||||
while (reader.hasNext()) {
|
|
||||||
Collection<Writable> newRecordAdd = new ArrayList<>();
|
|
||||||
// create a message body
|
|
||||||
Collection<Writable> next = reader.next();
|
|
||||||
for (Writable writable : next) {
|
|
||||||
newRecordAdd.add(writableConverter.convert(writable));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
newRecord.add(newRecordAdd);
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
while (reader.hasNext()) {
|
|
||||||
// create a message body
|
|
||||||
Collection<Writable> next = reader.next();
|
|
||||||
newRecord.add(next);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
exchange.getIn().setBody(newRecord);
|
|
||||||
exchange.getOut().setBody(newRecord);
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,42 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.datavec.camel.component.csv.marshaller;
|
|
||||||
|
|
||||||
import org.apache.camel.Exchange;
|
|
||||||
import org.datavec.api.split.InputSplit;
|
|
||||||
import org.datavec.api.split.ListStringSplit;
|
|
||||||
import org.datavec.camel.component.DataVecMarshaller;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Marshals List<List<String>>
|
|
||||||
*
|
|
||||||
* @author Adam Gibson
|
|
||||||
*/
|
|
||||||
public class ListStringInputMarshaller implements DataVecMarshaller {
|
|
||||||
/**
|
|
||||||
* @param exchange
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public InputSplit getSplit(Exchange exchange) {
|
|
||||||
List<List<String>> data = (List<List<String>>) exchange.getIn().getBody();
|
|
||||||
InputSplit listSplit = new ListStringSplit(data);
|
|
||||||
return listSplit;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1 +0,0 @@
|
||||||
class=org.datavec.camel.component.DataVecComponent
|
|
|
@ -1,82 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.datavec.camel.component;
|
|
||||||
|
|
||||||
import org.apache.camel.builder.RouteBuilder;
|
|
||||||
import org.apache.camel.component.mock.MockEndpoint;
|
|
||||||
import org.apache.camel.test.junit4.CamelTestSupport;
|
|
||||||
import org.apache.commons.io.FileUtils;
|
|
||||||
import org.datavec.api.records.reader.RecordReader;
|
|
||||||
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
|
|
||||||
import org.datavec.api.split.FileSplit;
|
|
||||||
import org.datavec.api.writable.Writable;
|
|
||||||
import org.junit.BeforeClass;
|
|
||||||
import org.junit.ClassRule;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.junit.rules.TemporaryFolder;
|
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collection;
|
|
||||||
|
|
||||||
public class DataVecComponentTest extends CamelTestSupport {
|
|
||||||
|
|
||||||
@ClassRule
|
|
||||||
public static TemporaryFolder testDir = new TemporaryFolder();
|
|
||||||
private static File dir;
|
|
||||||
private static File irisFile;
|
|
||||||
|
|
||||||
|
|
||||||
@BeforeClass
|
|
||||||
public static void before() throws Exception {
|
|
||||||
dir = testDir.newFolder();
|
|
||||||
File iris = new ClassPathResource("iris.dat").getFile();
|
|
||||||
irisFile = new File(dir, "iris.dat");
|
|
||||||
FileUtils.copyFile(iris, irisFile );
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testDataVec() throws Exception {
|
|
||||||
MockEndpoint mock = getMockEndpoint("mock:result");
|
|
||||||
//1
|
|
||||||
mock.expectedMessageCount(1);
|
|
||||||
|
|
||||||
RecordReader reader = new CSVRecordReader();
|
|
||||||
reader.initialize(new FileSplit(new ClassPathResource("iris.dat").getFile()));
|
|
||||||
Collection<Collection<Writable>> recordAssertion = new ArrayList<>();
|
|
||||||
while (reader.hasNext())
|
|
||||||
recordAssertion.add(reader.next());
|
|
||||||
mock.expectedBodiesReceived(recordAssertion);
|
|
||||||
assertMockEndpointsSatisfied();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected RouteBuilder createRouteBuilder() throws Exception {
|
|
||||||
|
|
||||||
|
|
||||||
return new RouteBuilder() {
|
|
||||||
public void configure() {
|
|
||||||
from("file:" + dir.getAbsolutePath() + "?fileName=iris.dat&noop=true").unmarshal().csv()
|
|
||||||
.to("datavec://org.datavec.api.formats.input.impl.ListStringInputFormat?inputMarshaller=org.datavec.camel.component.ListStringInputMarshaller&writableConverter=org.datavec.api.io.converters.SelfWritableConverter")
|
|
||||||
.to("mock:result");
|
|
||||||
}
|
|
||||||
};
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,41 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.datavec.camel.component;
|
|
||||||
|
|
||||||
import org.apache.camel.Exchange;
|
|
||||||
import org.datavec.api.split.InputSplit;
|
|
||||||
import org.datavec.api.split.ListStringSplit;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Marshals List<List<String>>
|
|
||||||
*
|
|
||||||
* @author Adam Gibson
|
|
||||||
*/
|
|
||||||
public class ListStringInputMarshaller implements DataVecMarshaller {
|
|
||||||
/**
|
|
||||||
* @param exchange
|
|
||||||
* @return
|
|
||||||
*/
|
|
||||||
@Override
|
|
||||||
public InputSplit getSplit(Exchange exchange) {
|
|
||||||
List<List<String>> data = (List<List<String>>) exchange.getIn().getBody();
|
|
||||||
InputSplit listSplit = new ListStringSplit(data);
|
|
||||||
return listSplit;
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,30 +0,0 @@
|
||||||
################################################################################
|
|
||||||
# Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
#
|
|
||||||
# This program and the accompanying materials are made available under the
|
|
||||||
# terms of the Apache License, Version 2.0 which is available at
|
|
||||||
# https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
# License for the specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
#
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
################################################################################
|
|
||||||
|
|
||||||
#
|
|
||||||
# The logging properties used
|
|
||||||
#
|
|
||||||
log4j.rootLogger=INFO, out
|
|
||||||
|
|
||||||
# uncomment the following line to turn on Camel debugging
|
|
||||||
#log4j.logger.org.apache.camel=DEBUG
|
|
||||||
|
|
||||||
# CONSOLE appender not used by default
|
|
||||||
log4j.appender.out=org.apache.log4j.ConsoleAppender
|
|
||||||
log4j.appender.out.layout=org.apache.log4j.PatternLayout
|
|
||||||
log4j.appender.out.layout.ConversionPattern=[%30.30t] %-30.30c{1} %-5p %m%n
|
|
||||||
#log4j.appender.out.layout.ConversionPattern=%d [%-15.15t] %-5p %-30.30c{1} - %m%n
|
|
||||||
|
|
|
@ -37,11 +37,6 @@
|
||||||
<version>${logback.version}</version>
|
<version>${logback.version}</version>
|
||||||
<scope>test</scope>
|
<scope>test</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-buffer</artifactId>
|
|
||||||
<version>${nd4j.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>com.github.jai-imageio</groupId>
|
<groupId>com.github.jai-imageio</groupId>
|
||||||
<artifactId>jai-imageio-core</artifactId>
|
<artifactId>jai-imageio-core</artifactId>
|
||||||
|
|
|
@ -1,65 +0,0 @@
|
||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
|
|
||||||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
~
|
|
||||||
~ This program and the accompanying materials are made available under the
|
|
||||||
~ terms of the Apache License, Version 2.0 which is available at
|
|
||||||
~ https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
~
|
|
||||||
~ Unless required by applicable law or agreed to in writing, software
|
|
||||||
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
~ License for the specific language governing permissions and limitations
|
|
||||||
~ under the License.
|
|
||||||
~
|
|
||||||
~ SPDX-License-Identifier: Apache-2.0
|
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
|
||||||
|
|
||||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
|
||||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
|
||||||
<parent>
|
|
||||||
<artifactId>datavec-parent</artifactId>
|
|
||||||
<groupId>org.datavec</groupId>
|
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
|
||||||
</parent>
|
|
||||||
<modelVersion>4.0.0</modelVersion>
|
|
||||||
|
|
||||||
<artifactId>datavec-perf</artifactId>
|
|
||||||
|
|
||||||
<name>datavec-perf</name>
|
|
||||||
|
|
||||||
<properties>
|
|
||||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
|
||||||
<maven.compiler.source>1.7</maven.compiler.source>
|
|
||||||
<maven.compiler.target>1.7</maven.compiler.target>
|
|
||||||
</properties>
|
|
||||||
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.slf4j</groupId>
|
|
||||||
<artifactId>slf4j-api</artifactId>
|
|
||||||
<version>${slf4j.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.datavec</groupId>
|
|
||||||
<artifactId>datavec-data-image</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.datavec</groupId>
|
|
||||||
<artifactId>datavec-api</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
|
|
||||||
<profiles>
|
|
||||||
<profile>
|
|
||||||
<id>test-nd4j-native</id>
|
|
||||||
</profile>
|
|
||||||
<profile>
|
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
|
||||||
</profile>
|
|
||||||
</profiles>
|
|
||||||
</project>
|
|
|
@ -1,112 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.datavec.perf.timing;
|
|
||||||
|
|
||||||
import lombok.val;
|
|
||||||
import org.datavec.api.records.reader.RecordReader;
|
|
||||||
import org.datavec.api.split.InputStreamInputSplit;
|
|
||||||
import org.datavec.api.writable.Writable;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.api.ops.performance.PerformanceTracker;
|
|
||||||
import org.nd4j.linalg.memory.MemcpyDirection;
|
|
||||||
|
|
||||||
import java.io.BufferedInputStream;
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.FileInputStream;
|
|
||||||
import java.io.InputStream;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Timing components of a data vec pipeline
|
|
||||||
* consisting of:
|
|
||||||
* {@link RecordReader}, {@link InputStreamInputSplit}
|
|
||||||
* (note that this uses input stream input split,
|
|
||||||
* the record reader must support {@link InputStreamInputSplit} for this to work)
|
|
||||||
*
|
|
||||||
* @author Adam Gibson
|
|
||||||
*/
|
|
||||||
public class IOTiming {
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns statistics for components of a datavec pipeline
|
|
||||||
* averaged over the specified number of times
|
|
||||||
* @param nTimes the number of times to run the pipeline for averaging
|
|
||||||
* @param recordReader the record reader
|
|
||||||
* @param file the file to read
|
|
||||||
* @param function the function
|
|
||||||
* @return the averaged {@link TimingStatistics} for input/output on a record
|
|
||||||
* reader and ndarray creation (based on the given function
|
|
||||||
* @throws Exception
|
|
||||||
*/
|
|
||||||
public static TimingStatistics averageFileRead(long nTimes, RecordReader recordReader, File file, INDArrayCreationFunction function) throws Exception {
|
|
||||||
TimingStatistics timingStatistics = null;
|
|
||||||
for(int i = 0; i < nTimes; i++) {
|
|
||||||
TimingStatistics timingStatistics1 = timeNDArrayCreation(recordReader,new BufferedInputStream(new FileInputStream(file)),function);
|
|
||||||
if(timingStatistics == null)
|
|
||||||
timingStatistics = timingStatistics1;
|
|
||||||
else {
|
|
||||||
timingStatistics = timingStatistics.add(timingStatistics1);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return timingStatistics.average(nTimes);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param reader
|
|
||||||
* @param inputStream
|
|
||||||
* @param function
|
|
||||||
* @return
|
|
||||||
* @throws Exception
|
|
||||||
*/
|
|
||||||
public static TimingStatistics timeNDArrayCreation(RecordReader reader,
|
|
||||||
InputStream inputStream,
|
|
||||||
INDArrayCreationFunction function) throws Exception {
|
|
||||||
|
|
||||||
|
|
||||||
reader.initialize(new InputStreamInputSplit(inputStream));
|
|
||||||
long longNanos = System.nanoTime();
|
|
||||||
List<Writable> next = reader.next();
|
|
||||||
long endNanos = System.nanoTime();
|
|
||||||
long etlDiff = endNanos - longNanos;
|
|
||||||
long startArrCreation = System.nanoTime();
|
|
||||||
INDArray arr = function.createFromRecord(next);
|
|
||||||
long endArrCreation = System.nanoTime();
|
|
||||||
long endCreationDiff = endArrCreation - startArrCreation;
|
|
||||||
Map<Integer, Map<MemcpyDirection, Long>> currentBandwidth = PerformanceTracker.getInstance().getCurrentBandwidth();
|
|
||||||
val bw = currentBandwidth.get(0).get(MemcpyDirection.HOST_TO_DEVICE);
|
|
||||||
val deviceToHost = currentBandwidth.get(0).get(MemcpyDirection.HOST_TO_DEVICE);
|
|
||||||
|
|
||||||
return TimingStatistics.builder()
|
|
||||||
.diskReadingTimeNanos(etlDiff)
|
|
||||||
.bandwidthNanosHostToDevice(bw)
|
|
||||||
.bandwidthDeviceToHost(deviceToHost)
|
|
||||||
.ndarrayCreationTimeNanos(endCreationDiff)
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
public interface INDArrayCreationFunction {
|
|
||||||
INDArray createFromRecord(List<Writable> record);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,74 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.datavec.perf.timing;
|
|
||||||
|
|
||||||
import lombok.Builder;
|
|
||||||
import lombok.Data;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The timing statistics for a data pipeline including:
|
|
||||||
* ndarray creation time in nanoseconds
|
|
||||||
* disk reading time in nanoseconds
|
|
||||||
* bandwidth used in host to device in nano seconds
|
|
||||||
* bandwidth device to host in nanoseconds
|
|
||||||
*
|
|
||||||
* @author Adam Gibson
|
|
||||||
*/
|
|
||||||
@Builder
|
|
||||||
@Data
|
|
||||||
public class TimingStatistics {
|
|
||||||
|
|
||||||
private long ndarrayCreationTimeNanos;
|
|
||||||
private long diskReadingTimeNanos;
|
|
||||||
private long bandwidthNanosHostToDevice;
|
|
||||||
private long bandwidthDeviceToHost;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Accumulate the given statistics
|
|
||||||
* @param timingStatistics the statistics to add
|
|
||||||
* @return the added statistics
|
|
||||||
*/
|
|
||||||
public TimingStatistics add(TimingStatistics timingStatistics) {
|
|
||||||
return TimingStatistics.builder()
|
|
||||||
.ndarrayCreationTimeNanos(ndarrayCreationTimeNanos + timingStatistics.ndarrayCreationTimeNanos)
|
|
||||||
.bandwidthNanosHostToDevice(bandwidthNanosHostToDevice + timingStatistics.bandwidthNanosHostToDevice)
|
|
||||||
.diskReadingTimeNanos(diskReadingTimeNanos + timingStatistics.diskReadingTimeNanos)
|
|
||||||
.bandwidthDeviceToHost(bandwidthDeviceToHost + timingStatistics.bandwidthDeviceToHost)
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Average the results relative to the number of n.
|
|
||||||
* This method is meant to be used alongside
|
|
||||||
* {@link #add(TimingStatistics)}
|
|
||||||
* accumulated a number of times
|
|
||||||
* @param n n the number of elements
|
|
||||||
* @return the averaged results
|
|
||||||
*/
|
|
||||||
public TimingStatistics average(long n) {
|
|
||||||
return TimingStatistics.builder()
|
|
||||||
.ndarrayCreationTimeNanos(ndarrayCreationTimeNanos / n)
|
|
||||||
.bandwidthDeviceToHost(bandwidthDeviceToHost / n)
|
|
||||||
.diskReadingTimeNanos(diskReadingTimeNanos / n)
|
|
||||||
.bandwidthNanosHostToDevice(bandwidthNanosHostToDevice / n)
|
|
||||||
.build();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,60 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.datavec.datavec.timing;
|
|
||||||
|
|
||||||
import org.datavec.api.records.reader.RecordReader;
|
|
||||||
import org.datavec.api.writable.NDArrayWritable;
|
|
||||||
import org.datavec.api.writable.Writable;
|
|
||||||
import org.datavec.image.loader.NativeImageLoader;
|
|
||||||
import org.datavec.image.recordreader.ImageRecordReader;
|
|
||||||
import org.datavec.perf.timing.IOTiming;
|
|
||||||
import org.datavec.perf.timing.TimingStatistics;
|
|
||||||
import org.junit.Test;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.io.ClassPathResource;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class IOTimingTest {
|
|
||||||
|
|
||||||
@Test
|
|
||||||
public void testTiming() throws Exception {
|
|
||||||
final RecordReader image = new ImageRecordReader(28,28);
|
|
||||||
final NativeImageLoader nativeImageLoader = new NativeImageLoader(28,28);
|
|
||||||
|
|
||||||
TimingStatistics timingStatistics = IOTiming.timeNDArrayCreation(image, new ClassPathResource("datavec-perf/largestblobtest.jpg").getInputStream(), new IOTiming.INDArrayCreationFunction() {
|
|
||||||
@Override
|
|
||||||
public INDArray createFromRecord(List<Writable> record) {
|
|
||||||
NDArrayWritable imageWritable = (NDArrayWritable) record.get(0);
|
|
||||||
return imageWritable.get();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
System.out.println(timingStatistics);
|
|
||||||
|
|
||||||
TimingStatistics timingStatistics1 = IOTiming.averageFileRead(1000,image,new ClassPathResource("datavec-perf/largestblobtest.jpg").getFile(), new IOTiming.INDArrayCreationFunction() {
|
|
||||||
@Override
|
|
||||||
public INDArray createFromRecord(List<Writable> record) {
|
|
||||||
NDArrayWritable imageWritable = (NDArrayWritable) record.get(0);
|
|
||||||
return imageWritable.get();
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
System.out.println(timingStatistics1);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -59,16 +59,12 @@
|
||||||
<modules>
|
<modules>
|
||||||
<module>datavec-api</module>
|
<module>datavec-api</module>
|
||||||
<module>datavec-data</module>
|
<module>datavec-data</module>
|
||||||
<module>datavec-geo</module>
|
|
||||||
<module>datavec-hadoop</module>
|
|
||||||
<module>datavec-spark</module>
|
<module>datavec-spark</module>
|
||||||
<module>datavec-camel</module>
|
|
||||||
<module>datavec-local</module>
|
<module>datavec-local</module>
|
||||||
<module>datavec-spark-inference-parent</module>
|
<module>datavec-spark-inference-parent</module>
|
||||||
<module>datavec-jdbc</module>
|
<module>datavec-jdbc</module>
|
||||||
<module>datavec-excel</module>
|
<module>datavec-excel</module>
|
||||||
<module>datavec-arrow</module>
|
<module>datavec-arrow</module>
|
||||||
<module>datavec-perf</module>
|
|
||||||
<module>datavec-python</module>
|
<module>datavec-python</module>
|
||||||
</modules>
|
</modules>
|
||||||
|
|
||||||
|
|
|
@ -153,6 +153,17 @@
|
||||||
<version>${jaxb.version}</version>
|
<version>${jaxb.version}</version>
|
||||||
<scope>provided</scope>
|
<scope>provided</scope>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<!-- Dependencies for dl4j-perf subproject -->
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.github.oshi</groupId>
|
||||||
|
<artifactId>oshi-json</artifactId>
|
||||||
|
<version>${oshi.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>com.github.oshi</groupId>
|
||||||
|
<artifactId>oshi-core</artifactId>
|
||||||
|
<version>${oshi.version}</version>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
|
|
@ -14,11 +14,13 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
package org.deeplearning4j.optimize.solvers.accumulation;
|
package org.deeplearning4j.optimize.solver.accumulation;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
|
import org.deeplearning4j.optimize.solvers.accumulation.EncodedGradientsAccumulator;
|
||||||
|
import org.deeplearning4j.optimize.solvers.accumulation.EncodingHandler;
|
||||||
import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.FixedThresholdAlgorithm;
|
import org.deeplearning4j.optimize.solvers.accumulation.encoding.threshold.FixedThresholdAlgorithm;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
import org.nd4j.linalg.api.ndarray.INDArray;
|
||||||
|
@ -68,8 +70,8 @@ public class EncodedGradientsAccumulatorTest extends BaseDL4JTest {
|
||||||
accumulator.receiveUpdate(encoded);
|
accumulator.receiveUpdate(encoded);
|
||||||
|
|
||||||
// just purge updates, like they were consumed
|
// just purge updates, like they were consumed
|
||||||
for (int i = 0; i < accumulator.messages.size(); i++) {
|
for (int i = 0; i < accumulator.getMessages().size(); i++) {
|
||||||
accumulator.messages.get(i).clear();
|
accumulator.getMessages().get(i).clear();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
|
@ -14,12 +14,13 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
package org.deeplearning4j.optimize.solvers.accumulation;
|
package org.deeplearning4j.optimize.solver.accumulation;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.apache.commons.lang3.RandomUtils;
|
import org.apache.commons.lang3.RandomUtils;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
|
import org.deeplearning4j.optimize.solvers.accumulation.IndexedTail;
|
||||||
import org.junit.Ignore;
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
import org.nd4j.linalg.factory.Nd4j;
|
||||||
|
@ -97,9 +98,9 @@ public class IndexedTailTest extends BaseDL4JTest {
|
||||||
assertEquals(-1, tail.maxAppliedIndexEverywhere());
|
assertEquals(-1, tail.maxAppliedIndexEverywhere());
|
||||||
|
|
||||||
// 2 consumers consumed 2 elements, and 1 consumer consumed 3 elements
|
// 2 consumers consumed 2 elements, and 1 consumer consumed 3 elements
|
||||||
tail.positions.get(11L).set(2);
|
tail.getPositions().get(11L).set(2);
|
||||||
tail.positions.get(22L).set(2);
|
tail.getPositions().get(22L).set(2);
|
||||||
tail.positions.get(33L).set(3);
|
tail.getPositions().get(33L).set(3);
|
||||||
|
|
||||||
// all elements including this index are safe to remove, because they were consumed everywhere
|
// all elements including this index are safe to remove, because they were consumed everywhere
|
||||||
assertEquals(2, tail.maxAppliedIndexEverywhere());
|
assertEquals(2, tail.maxAppliedIndexEverywhere());
|
||||||
|
@ -197,10 +198,10 @@ public class IndexedTailTest extends BaseDL4JTest {
|
||||||
Nd4j.getExecutioner().commit();
|
Nd4j.getExecutioner().commit();
|
||||||
}
|
}
|
||||||
|
|
||||||
assertTrue(tail.collapsedMode.get());
|
assertTrue(tail.getCollapsedMode().get());
|
||||||
assertEquals(1, tail.updatesSize());
|
assertEquals(1, tail.updatesSize());
|
||||||
|
|
||||||
val array = tail.updates.get(32L);
|
val array = tail.getUpdates().get(32L);
|
||||||
assertNotNull(array);
|
assertNotNull(array);
|
||||||
assertEquals(sum, (int) array.getDouble(0));
|
assertEquals(sum, (int) array.getDouble(0));
|
||||||
}
|
}
|
|
@ -14,12 +14,13 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
package org.deeplearning4j.optimize.solvers.accumulation;
|
package org.deeplearning4j.optimize.solver.accumulation;
|
||||||
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
import org.apache.commons.lang3.RandomUtils;
|
import org.apache.commons.lang3.RandomUtils;
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
|
import org.deeplearning4j.optimize.solvers.accumulation.SmartFancyBlockingQueue;
|
||||||
import org.deeplearning4j.util.ThreadUtils;
|
import org.deeplearning4j.util.ThreadUtils;
|
||||||
import org.junit.Ignore;
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
|
@ -14,7 +14,7 @@
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
* SPDX-License-Identifier: Apache-2.0
|
||||||
******************************************************************************/
|
******************************************************************************/
|
||||||
|
|
||||||
package org.deeplearning4j.optimize.solvers.accumulation;
|
package org.deeplearning4j.optimize.solver.accumulation;
|
||||||
|
|
||||||
import org.deeplearning4j.BaseDL4JTest;
|
import org.deeplearning4j.BaseDL4JTest;
|
||||||
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm;
|
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm;
|
|
@ -1,5 +1,6 @@
|
||||||
package org.deeplearning4j.optimize.listeners;
|
package org.deeplearning4j.optimizer.listener;
|
||||||
|
|
||||||
|
import org.deeplearning4j.optimize.listeners.CollectScoresIterationListener;
|
||||||
import org.junit.Ignore;
|
import org.junit.Ignore;
|
||||||
import org.junit.Test;
|
import org.junit.Test;
|
||||||
|
|
|
@ -305,7 +305,7 @@ public class CuDNNGradientChecks extends BaseDL4JTest {
|
||||||
//However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
|
//However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
|
||||||
Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev"));
|
Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev"));
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, excludeParams);
|
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, false, -1, excludeParams, null);
|
||||||
|
|
||||||
assertTrue(gradOK);
|
assertTrue(gradOK);
|
||||||
}
|
}
|
||||||
|
@ -417,7 +417,7 @@ public class CuDNNGradientChecks extends BaseDL4JTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 32);
|
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, input, labels, null, null, true, 32, null, null);
|
||||||
|
|
||||||
assertTrue(gradOK);
|
assertTrue(gradOK);
|
||||||
}
|
}
|
||||||
|
@ -655,9 +655,12 @@ public class CuDNNGradientChecks extends BaseDL4JTest {
|
||||||
};
|
};
|
||||||
|
|
||||||
log.info("*** Starting test: " + msg + " ***");
|
log.info("*** Starting test: " + msg + " ***");
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(mln, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(
|
||||||
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, f, l, null, null,
|
new GradientCheckUtil.MLNConfig().net(mln).epsilon(DEFAULT_EPS)
|
||||||
false, -1, null, c);
|
.maxRelError(DEFAULT_MAX_REL_ERROR).minAbsoluteError(DEFAULT_MIN_ABS_ERROR)
|
||||||
|
.print(PRINT_RESULTS ? GradientCheckUtil.PrintMode.ZEROS : GradientCheckUtil.PrintMode.FAILURES_ONLY)
|
||||||
|
.exitOnFirstError(RETURN_ON_FIRST_FAILURE)
|
||||||
|
.input(f).labels(l).callEachIter(c));
|
||||||
|
|
||||||
assertTrue(msg, gradOK);
|
assertTrue(msg, gradOK);
|
||||||
TestUtils.testModelSerialization(mln);
|
TestUtils.testModelSerialization(mln);
|
||||||
|
@ -691,7 +694,7 @@ public class CuDNNGradientChecks extends BaseDL4JTest {
|
||||||
//However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
|
//However, numerical gradient will be 0 as forward pass doesn't depend on this "parameter"
|
||||||
Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev"));
|
Set<String> excludeParams = new HashSet<>(Arrays.asList("1_mean", "1_var", "1_log10stdev"));
|
||||||
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
boolean gradOK = GradientCheckUtil.checkGradients(net, DEFAULT_EPS, DEFAULT_MAX_REL_ERROR,
|
||||||
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, excludeParams);
|
DEFAULT_MIN_ABS_ERROR, PRINT_RESULTS, RETURN_ON_FIRST_FAILURE, in, labels, null, null, false, -1, excludeParams, null);
|
||||||
|
|
||||||
assertTrue(gradOK);
|
assertTrue(gradOK);
|
||||||
|
|
||||||
|
|
|
@ -63,6 +63,7 @@ public class EncodedGradientsAccumulator implements GradientsAccumulator, Regist
|
||||||
protected int parties;
|
protected int parties;
|
||||||
@Getter
|
@Getter
|
||||||
protected MessageHandler handler;
|
protected MessageHandler handler;
|
||||||
|
@Getter
|
||||||
protected List<BlockingQueue<INDArray>> messages = new ArrayList<>();
|
protected List<BlockingQueue<INDArray>> messages = new ArrayList<>();
|
||||||
protected List<MemoryWorkspace> workspaces = new ArrayList<>();
|
protected List<MemoryWorkspace> workspaces = new ArrayList<>();
|
||||||
protected List<ReentrantLock> locks = new ArrayList<>();
|
protected List<ReentrantLock> locks = new ArrayList<>();
|
||||||
|
@ -106,7 +107,7 @@ public class EncodedGradientsAccumulator implements GradientsAccumulator, Regist
|
||||||
this(parties, new EncodingHandler(thresholdAlgorithm, residualPostProcessor, 1.0, encodingDebugMode), DEFAULT_INITIAL_MEMORY, 10, 1.0, encodingDebugMode);
|
this(parties, new EncodingHandler(thresholdAlgorithm, residualPostProcessor, 1.0, encodingDebugMode), DEFAULT_INITIAL_MEMORY, 10, 1.0, encodingDebugMode);
|
||||||
}
|
}
|
||||||
|
|
||||||
protected EncodedGradientsAccumulator(int parties, @NonNull MessageHandler handler, long initialMemory,
|
public EncodedGradientsAccumulator(int parties, @NonNull MessageHandler handler, long initialMemory,
|
||||||
int queueSize, Double boundary, boolean encodingDebugMode) {
|
int queueSize, Double boundary, boolean encodingDebugMode) {
|
||||||
this.parties = parties;
|
this.parties = parties;
|
||||||
this.handler = handler;
|
this.handler = handler;
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
package org.deeplearning4j.optimize.solvers.accumulation;
|
package org.deeplearning4j.optimize.solvers.accumulation;
|
||||||
|
|
||||||
|
import lombok.Getter;
|
||||||
import lombok.NonNull;
|
import lombok.NonNull;
|
||||||
import lombok.extern.slf4j.Slf4j;
|
import lombok.extern.slf4j.Slf4j;
|
||||||
import lombok.val;
|
import lombok.val;
|
||||||
|
@ -44,9 +45,11 @@ import java.util.concurrent.locks.ReentrantReadWriteLock;
|
||||||
@Slf4j
|
@Slf4j
|
||||||
public class IndexedTail {
|
public class IndexedTail {
|
||||||
// here we store positions of individual consumers
|
// here we store positions of individual consumers
|
||||||
|
@Getter
|
||||||
protected ConcurrentHashMap<Long, AtomicLong> positions = new ConcurrentHashMap<>();
|
protected ConcurrentHashMap<Long, AtomicLong> positions = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
// here we store individual updates
|
// here we store individual updates
|
||||||
|
@Getter
|
||||||
protected Map<Long, INDArray> updates = new ConcurrentHashMap<>();
|
protected Map<Long, INDArray> updates = new ConcurrentHashMap<>();
|
||||||
|
|
||||||
// simple counter for new updates
|
// simple counter for new updates
|
||||||
|
@ -67,6 +70,7 @@ public class IndexedTail {
|
||||||
protected final boolean allowCollapse;
|
protected final boolean allowCollapse;
|
||||||
protected final long[] shape;
|
protected final long[] shape;
|
||||||
protected final int collapseThreshold = 32;
|
protected final int collapseThreshold = 32;
|
||||||
|
@Getter
|
||||||
protected AtomicBoolean collapsedMode = new AtomicBoolean(false);
|
protected AtomicBoolean collapsedMode = new AtomicBoolean(false);
|
||||||
protected AtomicLong collapsedIndex = new AtomicLong(-1);
|
protected AtomicLong collapsedIndex = new AtomicLong(-1);
|
||||||
|
|
||||||
|
@ -148,7 +152,7 @@ public class IndexedTail {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
protected long firstNotAppliedIndexEverywhere() {
|
public long firstNotAppliedIndexEverywhere() {
|
||||||
long maxIdx = -1;
|
long maxIdx = -1;
|
||||||
|
|
||||||
// if there's no updates posted yet - just return negative value
|
// if there's no updates posted yet - just return negative value
|
||||||
|
@ -163,7 +167,7 @@ public class IndexedTail {
|
||||||
return maxIdx + 1;
|
return maxIdx + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected long maxAppliedIndexEverywhere() {
|
public long maxAppliedIndexEverywhere() {
|
||||||
long maxIdx = Long.MAX_VALUE;
|
long maxIdx = Long.MAX_VALUE;
|
||||||
for (val v:positions.values()) {
|
for (val v:positions.values()) {
|
||||||
if (v.get() < maxIdx)
|
if (v.get() < maxIdx)
|
||||||
|
@ -212,7 +216,7 @@ public class IndexedTail {
|
||||||
return getDelta(Thread.currentThread().getId());
|
return getDelta(Thread.currentThread().getId());
|
||||||
}
|
}
|
||||||
|
|
||||||
protected long getDelta(long threadId) {
|
public long getDelta(long threadId) {
|
||||||
return getGlobalPosition() - getLocalPosition(threadId);
|
return getGlobalPosition() - getLocalPosition(threadId);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -293,7 +297,7 @@ public class IndexedTail {
|
||||||
/**
|
/**
|
||||||
* This method does maintenance of updates within
|
* This method does maintenance of updates within
|
||||||
*/
|
*/
|
||||||
protected synchronized void maintenance() {
|
public synchronized void maintenance() {
|
||||||
// first of all we're checking, if all consumers were already registered. if not - just no-op.
|
// first of all we're checking, if all consumers were already registered. if not - just no-op.
|
||||||
if (positions.size() < expectedConsumers) {
|
if (positions.size() < expectedConsumers) {
|
||||||
log.trace("Skipping maintanance due to not all expected consumers shown up: [{}] vs [{}]", positions.size(), expectedConsumers);
|
log.trace("Skipping maintanance due to not all expected consumers shown up: [{}] vs [{}]", positions.size(), expectedConsumers);
|
||||||
|
@ -326,7 +330,7 @@ public class IndexedTail {
|
||||||
* This method returns actual number of updates stored within tail
|
* This method returns actual number of updates stored within tail
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
protected int updatesSize() {
|
public int updatesSize() {
|
||||||
return updates.size();
|
return updates.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -348,11 +352,11 @@ public class IndexedTail {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected boolean isDead() {
|
public boolean isDead() {
|
||||||
return dead.get();
|
return dead.get();
|
||||||
}
|
}
|
||||||
|
|
||||||
protected void notifyDead() {
|
public void notifyDead() {
|
||||||
dead.set(true);
|
dead.set(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,158 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j;
|
|
||||||
|
|
||||||
import org.apache.commons.compress.utils.IOUtils;
|
|
||||||
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
|
|
||||||
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
|
|
||||||
import org.deeplearning4j.nn.graph.ComputationGraph;
|
|
||||||
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
|
|
||||||
import org.deeplearning4j.util.ModelSerializer;
|
|
||||||
import org.nd4j.linalg.api.ndarray.INDArray;
|
|
||||||
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
|
|
||||||
import org.nd4j.linalg.factory.Nd4j;
|
|
||||||
|
|
||||||
import java.io.*;
|
|
||||||
import java.util.Random;
|
|
||||||
|
|
||||||
import static org.junit.Assert.assertEquals;
|
|
||||||
|
|
||||||
public class TestUtils {
|
|
||||||
|
|
||||||
public static MultiLayerNetwork testModelSerialization(MultiLayerNetwork net){
|
|
||||||
|
|
||||||
MultiLayerNetwork restored;
|
|
||||||
try {
|
|
||||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
|
||||||
ModelSerializer.writeModel(net, baos, true);
|
|
||||||
byte[] bytes = baos.toByteArray();
|
|
||||||
|
|
||||||
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
|
|
||||||
restored = ModelSerializer.restoreMultiLayerNetwork(bais, true);
|
|
||||||
|
|
||||||
assertEquals(net.getLayerWiseConfigurations(), restored.getLayerWiseConfigurations());
|
|
||||||
assertEquals(net.params(), restored.params());
|
|
||||||
} catch (IOException e){
|
|
||||||
//Should never happen
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
//Also check the MultiLayerConfiguration is serializable (required by Spark etc)
|
|
||||||
MultiLayerConfiguration conf = net.getLayerWiseConfigurations();
|
|
||||||
serializeDeserializeJava(conf);
|
|
||||||
|
|
||||||
return restored;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static ComputationGraph testModelSerialization(ComputationGraph net){
|
|
||||||
|
|
||||||
ComputationGraph restored;
|
|
||||||
try {
|
|
||||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
|
||||||
ModelSerializer.writeModel(net, baos, true);
|
|
||||||
byte[] bytes = baos.toByteArray();
|
|
||||||
|
|
||||||
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
|
|
||||||
restored = ModelSerializer.restoreComputationGraph(bais, true);
|
|
||||||
|
|
||||||
assertEquals(net.getConfiguration(), restored.getConfiguration());
|
|
||||||
assertEquals(net.params(), restored.params());
|
|
||||||
} catch (IOException e){
|
|
||||||
//Should never happen
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
//Also check the ComputationGraphConfiguration is serializable (required by Spark etc)
|
|
||||||
ComputationGraphConfiguration conf = net.getConfiguration();
|
|
||||||
serializeDeserializeJava(conf);
|
|
||||||
|
|
||||||
return restored;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static <T> T serializeDeserializeJava(T object){
|
|
||||||
byte[] bytes;
|
|
||||||
try(ByteArrayOutputStream baos = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(baos)){
|
|
||||||
oos.writeObject(object);
|
|
||||||
oos.close();
|
|
||||||
bytes = baos.toByteArray();
|
|
||||||
} catch (IOException e){
|
|
||||||
//Should never happen
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
T out;
|
|
||||||
try(ObjectInputStream ois = new ObjectInputStream(new ByteArrayInputStream(bytes))){
|
|
||||||
out = (T)ois.readObject();
|
|
||||||
} catch (IOException | ClassNotFoundException e){
|
|
||||||
throw new RuntimeException(e);
|
|
||||||
}
|
|
||||||
|
|
||||||
assertEquals(object, out);
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static INDArray randomOneHot(long examples, long nOut){
|
|
||||||
return randomOneHot(examples, nOut, new Random(12345));
|
|
||||||
}
|
|
||||||
|
|
||||||
public static INDArray randomOneHot(long examples, long nOut, long rngSeed){
|
|
||||||
return randomOneHot(examples, nOut, new Random(rngSeed));
|
|
||||||
}
|
|
||||||
|
|
||||||
public static INDArray randomOneHot(long examples, long nOut, Random rng){
|
|
||||||
INDArray arr = Nd4j.create(examples, nOut);
|
|
||||||
for( int i=0; i<examples; i++ ){
|
|
||||||
arr.putScalar(i, rng.nextInt((int) nOut), 1.0);
|
|
||||||
}
|
|
||||||
return arr;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength){
|
|
||||||
return randomOneHotTimeSeries(minibatch, outSize, tsLength, new Random());
|
|
||||||
}
|
|
||||||
|
|
||||||
public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, long rngSeed){
|
|
||||||
return randomOneHotTimeSeries(minibatch, outSize, tsLength, new Random(rngSeed));
|
|
||||||
}
|
|
||||||
|
|
||||||
public static INDArray randomOneHotTimeSeries(int minibatch, int outSize, int tsLength, Random rng){
|
|
||||||
INDArray out = Nd4j.create(new int[]{minibatch, outSize, tsLength}, 'f');
|
|
||||||
for( int i=0; i<minibatch; i++ ){
|
|
||||||
for( int j=0; j<tsLength; j++ ){
|
|
||||||
out.putScalar(i, rng.nextInt(outSize), j, 1.0);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static INDArray randomBernoulli(int... shape) {
|
|
||||||
return randomBernoulli(0.5, shape);
|
|
||||||
}
|
|
||||||
|
|
||||||
public static INDArray randomBernoulli(double p, int... shape){
|
|
||||||
INDArray ret = Nd4j.createUninitialized(shape);
|
|
||||||
Nd4j.getExecutioner().exec(new BernoulliDistribution(ret, p));
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
public static void writeStreamToFile(File out, InputStream is) throws IOException {
|
|
||||||
byte[] b = IOUtils.toByteArray(is);
|
|
||||||
try (OutputStream os = new BufferedOutputStream(new FileOutputStream(out))) {
|
|
||||||
os.write(b);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,99 +0,0 @@
|
||||||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
~
|
|
||||||
~ This program and the accompanying materials are made available under the
|
|
||||||
~ terms of the Apache License, Version 2.0 which is available at
|
|
||||||
~ https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
~
|
|
||||||
~ Unless required by applicable law or agreed to in writing, software
|
|
||||||
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
~ License for the specific language governing permissions and limitations
|
|
||||||
~ under the License.
|
|
||||||
~
|
|
||||||
~ SPDX-License-Identifier: Apache-2.0
|
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
|
||||||
|
|
||||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
|
||||||
<modelVersion>4.0.0</modelVersion>
|
|
||||||
<parent>
|
|
||||||
<groupId>org.deeplearning4j</groupId>
|
|
||||||
<artifactId>deeplearning4j-scaleout</artifactId>
|
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
|
||||||
</parent>
|
|
||||||
|
|
||||||
<artifactId>deeplearning4j-aws_2.11</artifactId>
|
|
||||||
<name>DeepLearning4j-AWS</name>
|
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
|
||||||
|
|
||||||
<build>
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<artifactId>maven-compiler-plugin</artifactId>
|
|
||||||
<configuration>
|
|
||||||
<source>1.8</source>
|
|
||||||
<target>1.8</target>
|
|
||||||
</configuration>
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</build>
|
|
||||||
<properties>
|
|
||||||
<!-- Default scala versions, may be overwritten by build profiles -->
|
|
||||||
<scala.version>2.11.12</scala.version>
|
|
||||||
<scala.binary.version>2.11</scala.binary.version>
|
|
||||||
</properties>
|
|
||||||
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>com.amazonaws</groupId>
|
|
||||||
<artifactId>aws-java-sdk</artifactId>
|
|
||||||
<version>1.11.24</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>args4j</groupId>
|
|
||||||
<artifactId>args4j</artifactId>
|
|
||||||
<version>2.32</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.slf4j</groupId>
|
|
||||||
<artifactId>slf4j-api</artifactId>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-api</artifactId>
|
|
||||||
<version>${nd4j.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.deeplearning4j</groupId>
|
|
||||||
<artifactId>deeplearning4j-util</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>com.jcraft</groupId>
|
|
||||||
<artifactId>jsch</artifactId>
|
|
||||||
<version>${jsch.version}</version>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.threadly</groupId>
|
|
||||||
<artifactId>threadly</artifactId>
|
|
||||||
<version>${threadly.version}</version>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.apache.commons</groupId>
|
|
||||||
<artifactId>commons-lang3</artifactId>
|
|
||||||
<version>${commons-lang3.version}</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
|
|
||||||
<profiles>
|
|
||||||
<profile>
|
|
||||||
<id>test-nd4j-native</id>
|
|
||||||
</profile>
|
|
||||||
<profile>
|
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
|
||||||
</profile>
|
|
||||||
</profiles>
|
|
||||||
</project>
|
|
|
@ -1,34 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.aws.dataset;
|
|
||||||
|
|
||||||
import org.deeplearning4j.aws.s3.reader.S3Downloader;
|
|
||||||
|
|
||||||
import java.io.InputStream;
|
|
||||||
|
|
||||||
public class DataSetLoader {
|
|
||||||
|
|
||||||
private String bucket;
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public void onData(InputStream is) {
|
|
||||||
S3Downloader downloader = new S3Downloader();
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,222 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.aws.ec2;
|
|
||||||
|
|
||||||
import com.amazonaws.regions.Regions;
|
|
||||||
import com.amazonaws.services.ec2.AmazonEC2;
|
|
||||||
import com.amazonaws.services.ec2.model.*;
|
|
||||||
import org.deeplearning4j.aws.s3.BaseS3;
|
|
||||||
import org.deeplearning4j.util.ThreadUtils;
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates Ec2Boxes
|
|
||||||
* @author Adam Gibson
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public class Ec2BoxCreator extends BaseS3 {
|
|
||||||
|
|
||||||
|
|
||||||
private String amiId;
|
|
||||||
private int numBoxes;
|
|
||||||
private String size;
|
|
||||||
private List<String> boxesCreated;
|
|
||||||
private String securityGroupId;
|
|
||||||
private String keyPair;
|
|
||||||
private Regions regions = Regions.DEFAULT_REGION;
|
|
||||||
private static final Logger log = LoggerFactory.getLogger(Ec2BoxCreator.class);
|
|
||||||
|
|
||||||
//centos
|
|
||||||
public final static String DEFAULT_AMI = "ami-8997afe0";
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param numBoxes number of boxes
|
|
||||||
* @param size the size of the instances
|
|
||||||
*/
|
|
||||||
public Ec2BoxCreator(int numBoxes, String size, String securityGroupId, String keyPair) {
|
|
||||||
this(DEFAULT_AMI, numBoxes, size, securityGroupId, keyPair);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param amiId amazon image id
|
|
||||||
* @param numBoxes number of boxes
|
|
||||||
* @param size the size of the instances
|
|
||||||
* @param securityGroupId
|
|
||||||
*/
|
|
||||||
public Ec2BoxCreator(String amiId, int numBoxes, String size, String securityGroupId, String keyPair) {
|
|
||||||
super();
|
|
||||||
this.amiId = amiId;
|
|
||||||
this.numBoxes = numBoxes;
|
|
||||||
this.size = size;
|
|
||||||
this.keyPair = keyPair;
|
|
||||||
this.securityGroupId = securityGroupId;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void createSpot() {
|
|
||||||
// Initializes a Spot Instance Request
|
|
||||||
RequestSpotInstancesRequest requestRequest = new RequestSpotInstancesRequest();
|
|
||||||
|
|
||||||
// Request 1 x t1.micro instance with a bid price of $0.03.
|
|
||||||
requestRequest.setSpotPrice("0.03");
|
|
||||||
requestRequest.setInstanceCount(Integer.valueOf(1));
|
|
||||||
|
|
||||||
// Setup the specifications of the launch. This includes the
|
|
||||||
// instance type (e.g. t1.micro) and the latest Amazon Linux
|
|
||||||
// AMI id available. Note, you should always use the latest
|
|
||||||
// Amazon Linux AMI id or another of your choosing.
|
|
||||||
LaunchSpecification launchSpecification = new LaunchSpecification();
|
|
||||||
launchSpecification.setImageId("ami-8c1fece5");
|
|
||||||
launchSpecification.setInstanceType("t1.micro");
|
|
||||||
|
|
||||||
// Add the security group to the request.
|
|
||||||
List<String> securityGroups = new ArrayList<>();
|
|
||||||
securityGroups.add("GettingStartedGroup");
|
|
||||||
launchSpecification.setSecurityGroups(securityGroups);
|
|
||||||
|
|
||||||
// Add the launch specifications to the request.
|
|
||||||
requestRequest.setLaunchSpecification(launchSpecification);
|
|
||||||
|
|
||||||
// Call the RequestSpotInstance API.
|
|
||||||
RequestSpotInstancesResult requestResult = getEc2().requestSpotInstances(requestRequest);
|
|
||||||
|
|
||||||
|
|
||||||
List<SpotInstanceRequest> requestResponses = requestResult.getSpotInstanceRequests();
|
|
||||||
|
|
||||||
// Setup an arraylist to collect all of the request ids we want to
|
|
||||||
// watch hit the running state.
|
|
||||||
List<String> spotInstanceRequestIds = new ArrayList<>();
|
|
||||||
|
|
||||||
// Add all of the request ids to the hashset, so we can determine when they hit the
|
|
||||||
// active state.
|
|
||||||
for (SpotInstanceRequest requestResponse : requestResponses) {
|
|
||||||
System.out.println("Created Spot Request: " + requestResponse.getSpotInstanceRequestId());
|
|
||||||
spotInstanceRequestIds.add(requestResponse.getSpotInstanceRequestId());
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setRegion(Regions regions) {
|
|
||||||
this.regions = regions;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Create the instances
|
|
||||||
*/
|
|
||||||
public void create() {
|
|
||||||
RunInstancesRequest runInstancesRequest =
|
|
||||||
new RunInstancesRequest().withImageId(amiId).withInstanceType(size).withKeyName(keyPair)
|
|
||||||
.withMinCount(1).withSecurityGroupIds(securityGroupId).withMaxCount(numBoxes);
|
|
||||||
AmazonEC2 ec2 = getEc2();
|
|
||||||
ec2.setRegion(com.amazonaws.regions.Region.getRegion(regions));
|
|
||||||
List<Instance> boxes = ec2.runInstances(runInstancesRequest).getReservation().getInstances();
|
|
||||||
if (boxesCreated == null) {
|
|
||||||
boxesCreated = new ArrayList<>();
|
|
||||||
for (Instance i : boxes)
|
|
||||||
boxesCreated.add(i.getInstanceId());
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
log.info("Boxes created " + boxesCreated);
|
|
||||||
} else {
|
|
||||||
blowupBoxes();
|
|
||||||
boxesCreated.clear();
|
|
||||||
for (Instance i : boxes)
|
|
||||||
boxesCreated.add(i.getInstanceId());
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public List<InstanceStateChange> blowupBoxes() {
|
|
||||||
TerminateInstancesRequest request = new TerminateInstancesRequest().withInstanceIds(boxesCreated);
|
|
||||||
|
|
||||||
if (boxesCreated != null) {
|
|
||||||
TerminateInstancesResult result = getEc2().terminateInstances(request);
|
|
||||||
List<InstanceStateChange> change = result.getTerminatingInstances();
|
|
||||||
log.info("Boxes destroyed " + boxesCreated);
|
|
||||||
return change;
|
|
||||||
}
|
|
||||||
|
|
||||||
return Collections.emptyList();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
public void blockTillAllRunning() {
|
|
||||||
while (!allRunning()) {
|
|
||||||
ThreadUtils.uncheckedSleep(1000);
|
|
||||||
log.info("Not all created...");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean allRunning() {
|
|
||||||
if (boxesCreated == null)
|
|
||||||
return false;
|
|
||||||
else {
|
|
||||||
DescribeInstancesResult result = getEc2().describeInstances();
|
|
||||||
List<Reservation> reservations = result.getReservations();
|
|
||||||
for (Reservation r : reservations) {
|
|
||||||
List<Instance> instances = r.getInstances();
|
|
||||||
for (Instance instance : instances) {
|
|
||||||
InstanceState state = instance.getState();
|
|
||||||
if (state.getCode() == 48)
|
|
||||||
continue;
|
|
||||||
if (state.getCode() != 16)
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<String> getHosts() {
|
|
||||||
DescribeInstancesResult result = getEc2().describeInstances();
|
|
||||||
List<String> hosts = new ArrayList<>();
|
|
||||||
List<Reservation> reservations = result.getReservations();
|
|
||||||
for (Reservation r : reservations) {
|
|
||||||
List<Instance> instances = r.getInstances();
|
|
||||||
for (Instance instance : instances) {
|
|
||||||
InstanceState state = instance.getState();
|
|
||||||
if (state.getCode() == 48)
|
|
||||||
continue;
|
|
||||||
hosts.add(instance.getPublicDnsName());
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return hosts;
|
|
||||||
}
|
|
||||||
|
|
||||||
public List<String> getBoxesCreated() {
|
|
||||||
return boxesCreated;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,122 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.aws.ec2.provision;
|
|
||||||
|
|
||||||
import com.amazonaws.regions.Regions;
|
|
||||||
import org.deeplearning4j.aws.ec2.Ec2BoxCreator;
|
|
||||||
import org.kohsuke.args4j.CmdLineException;
|
|
||||||
import org.kohsuke.args4j.CmdLineParser;
|
|
||||||
import org.kohsuke.args4j.Option;
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
import org.threadly.concurrent.PriorityScheduler;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Sets up a DL4J cluster
|
|
||||||
* @author Adam Gibson
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public class ClusterSetup {
|
|
||||||
|
|
||||||
@Option(name = "-w", usage = "Number of workers")
|
|
||||||
private int numWorkers = 1;
|
|
||||||
@Option(name = "-ami", usage = "Amazon machine image: default, amazon linux (only works with RHEL right now")
|
|
||||||
private String ami = "ami-fb8e9292";
|
|
||||||
@Option(name = "-s", usage = "size of instance: default m1.medium")
|
|
||||||
private String size = "m3.xlarge";
|
|
||||||
@Option(name = "-sg", usage = "security group, this needs to be applyTransformToDestination")
|
|
||||||
private String securityGroupName;
|
|
||||||
@Option(name = "-kp", usage = "key pair name, also needs to be applyTransformToDestination.")
|
|
||||||
private String keyPairName;
|
|
||||||
@Option(name = "-kpath",
|
|
||||||
usage = "path to private key - needs to be applyTransformToDestination, this is used to login to amazon.")
|
|
||||||
private String pathToPrivateKey;
|
|
||||||
@Option(name = "-wscript", usage = "path to worker script to run, this will allow customization of dependencies")
|
|
||||||
private String workerSetupScriptPath;
|
|
||||||
@Option(name = "-mscript", usage = "path to master script to run this will allow customization of the dependencies")
|
|
||||||
private String masterSetupScriptPath;
|
|
||||||
@Option(name = "-region", usage = "specify a region")
|
|
||||||
private String region = Regions.US_EAST_1.getName();
|
|
||||||
|
|
||||||
private PriorityScheduler as;
|
|
||||||
|
|
||||||
private static final Logger log = LoggerFactory.getLogger(ClusterSetup.class);
|
|
||||||
|
|
||||||
|
|
||||||
public ClusterSetup(String[] args) {
|
|
||||||
CmdLineParser parser = new CmdLineParser(this);
|
|
||||||
try {
|
|
||||||
parser.parseArgument(args);
|
|
||||||
} catch (CmdLineException e) {
|
|
||||||
parser.printUsage(System.err);
|
|
||||||
log.error("Unable to parse args", e);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public void exec() {
|
|
||||||
//master + workers
|
|
||||||
Ec2BoxCreator boxCreator = new Ec2BoxCreator(ami, numWorkers, size, securityGroupName, keyPairName);
|
|
||||||
boxCreator.setRegion(Regions.fromName(region));
|
|
||||||
boxCreator.create();
|
|
||||||
boxCreator.blockTillAllRunning();
|
|
||||||
List<String> hosts = boxCreator.getHosts();
|
|
||||||
//provisionMaster(hosts.get(0));
|
|
||||||
provisionWorkers(hosts);
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
private void provisionWorkers(List<String> workers) {
|
|
||||||
as = new PriorityScheduler(Runtime.getRuntime().availableProcessors());
|
|
||||||
for (final String workerHost : workers) {
|
|
||||||
try {
|
|
||||||
as.execute(new Runnable() {
|
|
||||||
@Override
|
|
||||||
public void run() {
|
|
||||||
HostProvisioner uploader = new HostProvisioner(workerHost, "ec2-user");
|
|
||||||
try {
|
|
||||||
uploader.addKeyFile(pathToPrivateKey);
|
|
||||||
//uploader.runRemoteCommand("sudo hostname " + workerHost);
|
|
||||||
uploader.uploadAndRun(workerSetupScriptPath, "");
|
|
||||||
} catch (Exception e) {
|
|
||||||
e.printStackTrace();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.error("Error ", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param args
|
|
||||||
*/
|
|
||||||
public static void main(String[] args) {
|
|
||||||
new ClusterSetup(args).exec();
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,31 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.aws.ec2.provision;
|
|
||||||
|
|
||||||
public class DistributedDeepLearningTrainer {
|
|
||||||
|
|
||||||
private DistributedDeepLearningTrainer() {}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @param args
|
|
||||||
*/
|
|
||||||
public static void main(String[] args) {
|
|
||||||
ClusterSetup clusterSet = new ClusterSetup(args);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,311 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.aws.ec2.provision;
|
|
||||||
|
|
||||||
import com.jcraft.jsch.*;
|
|
||||||
import org.slf4j.Logger;
|
|
||||||
import org.slf4j.LoggerFactory;
|
|
||||||
|
|
||||||
import java.io.BufferedInputStream;
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.FileInputStream;
|
|
||||||
import java.io.InputStream;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Collection;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Meant for uploading files to remote servers
|
|
||||||
* @author Adam Gibson
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public class HostProvisioner implements UserInfo {
|
|
||||||
|
|
||||||
private String host;
|
|
||||||
private JSch jsch;
|
|
||||||
private String user;
|
|
||||||
private int port = 22;
|
|
||||||
private String password;
|
|
||||||
private static final Logger log = LoggerFactory.getLogger(HostProvisioner.class);
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param host host to connect to (public facing dns)
|
|
||||||
* @param user the user to connect with (default root otherwise)
|
|
||||||
* @param password the password to use if any
|
|
||||||
* @param port the port to connect to(default 22)
|
|
||||||
*/
|
|
||||||
public HostProvisioner(String host, String user, String password, int port) {
|
|
||||||
super();
|
|
||||||
this.host = host;
|
|
||||||
this.user = user;
|
|
||||||
this.port = port;
|
|
||||||
this.password = password;
|
|
||||||
jsch = new JSch();
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Connects to port 22
|
|
||||||
* @param host host to connect to (public facing dns)
|
|
||||||
* @param user the user to connect with (default root otherwise)
|
|
||||||
* @param password the password to use if any
|
|
||||||
*/
|
|
||||||
public HostProvisioner(String host, String user, String password) {
|
|
||||||
this(host, user, password, 22);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Connects to port 22
|
|
||||||
* @param host host to connect to (public facing dns)
|
|
||||||
* @param user the user to connect with (default root otherwise)
|
|
||||||
*/
|
|
||||||
public HostProvisioner(String host, String user) {
|
|
||||||
this(host, user, "", 22);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Connects to port 22, user root, with no password
|
|
||||||
* @param host host to connect to (public facing dns)
|
|
||||||
*/
|
|
||||||
public HostProvisioner(String host) {
|
|
||||||
this(host, "root", "", 22);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public void uploadAndRun(String script, String rootDir) throws Exception {
|
|
||||||
String remoteName = rootDir.isEmpty() ? new File(script).getName() : rootDir + "/" + new File(script).getName();
|
|
||||||
upload(new File(script), remoteName);
|
|
||||||
|
|
||||||
String remoteCommand = remoteName.charAt(0) != '/' ? "./" + remoteName : remoteName;
|
|
||||||
remoteCommand = "chmod +x " + remoteCommand + " && " + remoteCommand;
|
|
||||||
runRemoteCommand(remoteCommand);
|
|
||||||
}
|
|
||||||
|
|
||||||
public void runRemoteCommand(String remoteCommand) throws Exception {
|
|
||||||
Session session = getSession();
|
|
||||||
session.connect();
|
|
||||||
ChannelExec channel = (ChannelExec) session.openChannel("exec");
|
|
||||||
|
|
||||||
|
|
||||||
channel.setCommand(remoteCommand);
|
|
||||||
channel.setErrStream(System.err);
|
|
||||||
channel.setPty(true);
|
|
||||||
|
|
||||||
channel.setOutputStream(System.out);
|
|
||||||
channel.connect();
|
|
||||||
channel.start();
|
|
||||||
InputStream input = channel.getInputStream();
|
|
||||||
|
|
||||||
//start reading the input from the executed commands on the shell
|
|
||||||
byte[] tmp = new byte[60000];
|
|
||||||
while (true) {
|
|
||||||
while (input.available() > 0) {
|
|
||||||
int i = input.read(tmp, 0, tmp.length);
|
|
||||||
if (i < 0)
|
|
||||||
break;
|
|
||||||
log.info(new String(tmp, 0, i));
|
|
||||||
}
|
|
||||||
if (channel.isClosed()) {
|
|
||||||
log.info("exit-status: " + channel.getExitStatus());
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
channel.disconnect();
|
|
||||||
session.disconnect();
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private Session getSession() throws Exception {
|
|
||||||
Session session = jsch.getSession(user, host, port);
|
|
||||||
session.setUserInfo(this);
|
|
||||||
return session;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates the directory for the file if necessary
|
|
||||||
* and uploads the file
|
|
||||||
* @param from the directory to upload from
|
|
||||||
* @param to the destination directory on the remote server
|
|
||||||
* @throws Exception
|
|
||||||
*/
|
|
||||||
public void uploadForDeployment(String from, String to) throws Exception {
|
|
||||||
File fromFile = new File(from);
|
|
||||||
if (!to.isEmpty() && fromFile.isDirectory())
|
|
||||||
mkDir(to);
|
|
||||||
else
|
|
||||||
upload(from, to);
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public void addKeyFile(String keyFile) throws Exception {
|
|
||||||
jsch.addIdentity(keyFile);
|
|
||||||
}
|
|
||||||
|
|
||||||
//creates the directory to upload to
|
|
||||||
private void mkDir(String dir) throws Exception {
|
|
||||||
Session session = getSession();
|
|
||||||
session.connect();
|
|
||||||
Channel channel = session.openChannel("sftp");
|
|
||||||
channel.connect();
|
|
||||||
|
|
||||||
ChannelSftp c = (ChannelSftp) channel;
|
|
||||||
if (!fileExists(dir, c))
|
|
||||||
c.mkdir(dir);
|
|
||||||
c.exit();
|
|
||||||
session.disconnect();
|
|
||||||
}
|
|
||||||
|
|
||||||
private boolean fileExists(String dir, ChannelSftp channel) {
|
|
||||||
try {
|
|
||||||
channel.stat(dir);
|
|
||||||
return true;
|
|
||||||
} catch (Exception e) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
//uploads the file or listed files in a directory
|
|
||||||
private void upload(String fileOrDir, String uploadRootDir) throws Exception {
|
|
||||||
if (uploadRootDir.isEmpty())
|
|
||||||
uploadRootDir = ".";
|
|
||||||
File origin = new File(fileOrDir);
|
|
||||||
|
|
||||||
if (fileOrDir.endsWith(".tar") || fileOrDir.endsWith(".tar.gz")) {
|
|
||||||
upload(new File(fileOrDir), uploadRootDir);
|
|
||||||
untar(uploadRootDir);
|
|
||||||
} else if (origin.isFile()) {
|
|
||||||
upload(new File(fileOrDir), uploadRootDir);
|
|
||||||
} else {
|
|
||||||
File[] childFiles = origin.listFiles();
|
|
||||||
if (childFiles != null)
|
|
||||||
upload(Arrays.asList(childFiles), uploadRootDir);
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private void untar(String targetRemoteFile) throws Exception {
|
|
||||||
this.runRemoteCommand("tar xvf " + targetRemoteFile);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void upload(Collection<File> files, String rootDir) throws Exception {
|
|
||||||
Session session = getSession();
|
|
||||||
session.connect();
|
|
||||||
Channel channel = session.openChannel("sftp");
|
|
||||||
channel.connect();
|
|
||||||
|
|
||||||
ChannelSftp c = (ChannelSftp) channel;
|
|
||||||
for (File f : files) {
|
|
||||||
if (f.isDirectory()) {
|
|
||||||
log.warn("Skipping " + f.getName());
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
log.info("Uploading " + f.getName());
|
|
||||||
BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f));
|
|
||||||
c.put(bis, rootDir + "/" + f.getName());
|
|
||||||
bis.close();
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
channel.disconnect();
|
|
||||||
session.disconnect();
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
private void upload(File f, String remoteFile) throws Exception {
|
|
||||||
Session session = getSession();
|
|
||||||
int numRetries = 0;
|
|
||||||
while (numRetries < 3 && !session.isConnected()) {
|
|
||||||
try {
|
|
||||||
session.connect();
|
|
||||||
} catch (Exception e) {
|
|
||||||
numRetries++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
try {
|
|
||||||
Channel channel = session.openChannel("sftp");
|
|
||||||
|
|
||||||
|
|
||||||
channel.connect();
|
|
||||||
|
|
||||||
ChannelSftp c = (ChannelSftp) channel;
|
|
||||||
|
|
||||||
BufferedInputStream bis = new BufferedInputStream(new FileInputStream(f));
|
|
||||||
if (this.fileExists(remoteFile, c))
|
|
||||||
if (f.isDirectory())
|
|
||||||
c.rmdir(remoteFile);
|
|
||||||
else
|
|
||||||
c.rm(remoteFile);
|
|
||||||
c.put(bis, remoteFile);
|
|
||||||
bis.close();
|
|
||||||
c.exit();
|
|
||||||
session.disconnect();
|
|
||||||
} catch (Exception e) {
|
|
||||||
log.info("Session was down...trying again", e);
|
|
||||||
upload(f, remoteFile);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getPassphrase() {
|
|
||||||
return this.password;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public String getPassword() {
|
|
||||||
return this.password;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean promptPassphrase(String arg0) {
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean promptPassword(String arg0) {
|
|
||||||
return true;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean promptYesNo(String arg0) {
|
|
||||||
return true;
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void showMessage(String arg0) {
|
|
||||||
log.info(arg0);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,47 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.aws.emr;
|
|
||||||
|
|
||||||
import com.amazonaws.services.elasticmapreduce.model.Configuration;
|
|
||||||
|
|
||||||
import lombok.*;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Map;
|
|
||||||
|
|
||||||
|
|
||||||
@Data
|
|
||||||
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
|
||||||
@NoArgsConstructor
|
|
||||||
@Builder
|
|
||||||
public class EmrConfig {
|
|
||||||
|
|
||||||
protected String classification;
|
|
||||||
protected Map<String, String> properties;
|
|
||||||
protected List<EmrConfig> configs;
|
|
||||||
|
|
||||||
Configuration toAwsConfig() {
|
|
||||||
Configuration config = new Configuration().withClassification(classification).withProperties(properties);
|
|
||||||
List<Configuration> subConfigs = new ArrayList<>();
|
|
||||||
for (EmrConfig conf : configs){
|
|
||||||
subConfigs.add(conf.toAwsConfig());
|
|
||||||
}
|
|
||||||
return config.withConfigurations(subConfigs);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,531 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.aws.emr;
|
|
||||||
|
|
||||||
import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduce;
|
|
||||||
import com.amazonaws.services.elasticmapreduce.AmazonElasticMapReduceClientBuilder;
|
|
||||||
import com.amazonaws.services.elasticmapreduce.model.*;
|
|
||||||
import com.amazonaws.services.s3.AmazonS3ClientBuilder;
|
|
||||||
import com.amazonaws.services.s3.AmazonS3URI;
|
|
||||||
import com.amazonaws.services.s3.model.PutObjectRequest;
|
|
||||||
import lombok.AccessLevel;
|
|
||||||
import lombok.AllArgsConstructor;
|
|
||||||
import lombok.Data;
|
|
||||||
import lombok.NoArgsConstructor;
|
|
||||||
import lombok.extern.slf4j.Slf4j;
|
|
||||||
import org.apache.commons.lang3.RandomStringUtils;
|
|
||||||
import org.nd4j.linalg.function.Function;
|
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.util.*;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Configuration for a Spark EMR cluster
|
|
||||||
*/
|
|
||||||
@Data
|
|
||||||
@AllArgsConstructor(access = AccessLevel.PRIVATE)
|
|
||||||
@NoArgsConstructor
|
|
||||||
@Slf4j
|
|
||||||
public class SparkEMRClient {
|
|
||||||
|
|
||||||
protected String sparkClusterName = RandomStringUtils.randomAlphanumeric(12);
|
|
||||||
protected String sparkAwsRegion = "us-east-1";
|
|
||||||
protected String sparkEmrRelease = "emr-5.9.0";
|
|
||||||
protected String sparkEmrServiceRole = "EMR_DefaultRole";
|
|
||||||
protected List<EmrConfig> sparkEmrConfigs = Collections.emptyList();
|
|
||||||
protected String sparkSubnetId = null;
|
|
||||||
protected List<String> sparkSecurityGroupIds = Collections.emptyList();
|
|
||||||
protected int sparkInstanceCount = 1;
|
|
||||||
protected String sparkInstanceType = "m3.xlarge";
|
|
||||||
protected Optional<Float> sparkInstanceBidPrice = Optional.empty();
|
|
||||||
protected String sparkInstanceRole = "EMR_EC2_DefaultRole";
|
|
||||||
protected String sparkS3JarFolder = "changeme";
|
|
||||||
protected int sparkTimeoutDurationMinutes = 90;
|
|
||||||
|
|
||||||
//underlying configs
|
|
||||||
protected AmazonElasticMapReduceClientBuilder sparkEmrClientBuilder;
|
|
||||||
protected AmazonS3ClientBuilder sparkS3ClientBuilder;
|
|
||||||
protected JobFlowInstancesConfig sparkJobFlowInstancesConfig;
|
|
||||||
protected RunJobFlowRequest sparkRunJobFlowRequest;
|
|
||||||
protected Function<PutObjectRequest, PutObjectRequest> sparkS3PutObjectDecorator;
|
|
||||||
protected Map<String, String> sparkSubmitConfs;
|
|
||||||
|
|
||||||
|
|
||||||
private static ClusterState[] activeClusterStates = new ClusterState[]{
|
|
||||||
ClusterState.RUNNING,
|
|
||||||
ClusterState.STARTING,
|
|
||||||
ClusterState.WAITING,
|
|
||||||
ClusterState.BOOTSTRAPPING};
|
|
||||||
|
|
||||||
private Optional<ClusterSummary> findClusterWithName(AmazonElasticMapReduce emr, String name) {
|
|
||||||
List<ClusterSummary> csrl = emr.listClusters((new ListClustersRequest()).withClusterStates(activeClusterStates)).getClusters();
|
|
||||||
for (ClusterSummary csr : csrl) {
|
|
||||||
if (csr.getName().equals(name)) return Optional.of(csr);
|
|
||||||
}
|
|
||||||
return Optional.empty();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates the current cluster
|
|
||||||
*/
|
|
||||||
public void createCluster() {
|
|
||||||
AmazonElasticMapReduce emr = sparkEmrClientBuilder.build();
|
|
||||||
Optional<ClusterSummary> csr = findClusterWithName(emr, sparkClusterName);
|
|
||||||
if (csr.isPresent()) {
|
|
||||||
String msg = String.format("A cluster with name %s and id %s is already deployed", sparkClusterName, csr.get().getId());
|
|
||||||
log.error(msg);
|
|
||||||
throw new IllegalStateException(msg);
|
|
||||||
} else {
|
|
||||||
RunJobFlowResult res = emr.runJobFlow(sparkRunJobFlowRequest);
|
|
||||||
String msg = String.format("Your cluster is launched with name %s and id %s.", sparkClusterName, res.getJobFlowId());
|
|
||||||
log.info(msg);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
private void logClusters(List<ClusterSummary> csrl) {
|
|
||||||
if (csrl.isEmpty()) log.info("No cluster found.");
|
|
||||||
else {
|
|
||||||
log.info(String.format("%d clusters found.", csrl.size()));
|
|
||||||
for (ClusterSummary csr : csrl) {
|
|
||||||
log.info(String.format("Name: %s | Id: %s", csr.getName(), csr.getId()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Lists existing active clusters Names
|
|
||||||
*
|
|
||||||
* @return cluster names
|
|
||||||
*/
|
|
||||||
public List<String> listActiveClusterNames() {
|
|
||||||
AmazonElasticMapReduce emr = sparkEmrClientBuilder.build();
|
|
||||||
List<ClusterSummary> csrl =
|
|
||||||
emr.listClusters(new ListClustersRequest().withClusterStates(activeClusterStates)).getClusters();
|
|
||||||
logClusters(csrl);
|
|
||||||
List<String> res = new ArrayList<>(csrl.size());
|
|
||||||
for (ClusterSummary csr : csrl) res.add(csr.getName());
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* List existing active cluster IDs
|
|
||||||
*
|
|
||||||
* @return cluster IDs
|
|
||||||
*/
|
|
||||||
public List<String> listActiveClusterIds() {
|
|
||||||
AmazonElasticMapReduce emr = sparkEmrClientBuilder.build();
|
|
||||||
List<ClusterSummary> csrl =
|
|
||||||
emr.listClusters(new ListClustersRequest().withClusterStates(activeClusterStates)).getClusters();
|
|
||||||
logClusters(csrl);
|
|
||||||
List<String> res = new ArrayList<>(csrl.size());
|
|
||||||
for (ClusterSummary csr : csrl) res.add(csr.getId());
|
|
||||||
return res;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Terminates a cluster
|
|
||||||
*/
|
|
||||||
public void terminateCluster() {
|
|
||||||
AmazonElasticMapReduce emr = sparkEmrClientBuilder.build();
|
|
||||||
Optional<ClusterSummary> optClusterSum = findClusterWithName(emr, sparkClusterName);
|
|
||||||
if (!optClusterSum.isPresent()) {
|
|
||||||
log.error(String.format("The cluster with name %s , requested for deletion, does not exist.", sparkClusterName));
|
|
||||||
} else {
|
|
||||||
String id = optClusterSum.get().getId();
|
|
||||||
emr.terminateJobFlows((new TerminateJobFlowsRequest()).withJobFlowIds(id));
|
|
||||||
log.info(String.format("The cluster with id %s is terminating.", id));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// The actual job-sumission logic
|
|
||||||
private void submitJob(AmazonElasticMapReduce emr, String mainClass, List<String> args, Map<String, String> sparkConfs, File uberJar) throws Exception {
|
|
||||||
AmazonS3URI s3Jar = new AmazonS3URI(sparkS3JarFolder + "/" + uberJar.getName());
|
|
||||||
log.info(String.format("Placing uberJar %s to %s", uberJar.getPath(), s3Jar.toString()));
|
|
||||||
PutObjectRequest putRequest = sparkS3PutObjectDecorator.apply(
|
|
||||||
new PutObjectRequest(s3Jar.getBucket(), s3Jar.getKey(), uberJar)
|
|
||||||
);
|
|
||||||
sparkS3ClientBuilder.build().putObject(putRequest);
|
|
||||||
// The order of these matters
|
|
||||||
List<String> sparkSubmitArgs = Arrays.asList(
|
|
||||||
"spark-submit",
|
|
||||||
"--deploy-mode",
|
|
||||||
"cluster",
|
|
||||||
"--class",
|
|
||||||
mainClass
|
|
||||||
);
|
|
||||||
for (Map.Entry<String, String> e : sparkConfs.entrySet()) {
|
|
||||||
sparkSubmitArgs.add(String.format("--conf %s = %s ", e.getKey(), e.getValue()));
|
|
||||||
}
|
|
||||||
sparkSubmitArgs.add(s3Jar.toString());
|
|
||||||
sparkSubmitArgs.addAll(args);
|
|
||||||
StepConfig step = new StepConfig()
|
|
||||||
.withActionOnFailure(ActionOnFailure.CONTINUE)
|
|
||||||
.withName("Spark step")
|
|
||||||
.withHadoopJarStep(
|
|
||||||
new HadoopJarStepConfig()
|
|
||||||
.withJar("command-runner.jar")
|
|
||||||
.withArgs(sparkSubmitArgs)
|
|
||||||
);
|
|
||||||
|
|
||||||
Optional<ClusterSummary> optCsr = findClusterWithName(emr, sparkClusterName);
|
|
||||||
if (optCsr.isPresent()) {
|
|
||||||
ClusterSummary csr = optCsr.get();
|
|
||||||
emr.addJobFlowSteps(
|
|
||||||
new AddJobFlowStepsRequest()
|
|
||||||
.withJobFlowId(csr.getId())
|
|
||||||
.withSteps(step));
|
|
||||||
log.info(
|
|
||||||
String.format("Your job is added to the cluster with id %s.", csr.getId())
|
|
||||||
);
|
|
||||||
} else {
|
|
||||||
// If the cluster wasn't started, it's assumed ot be throwaway
|
|
||||||
List<StepConfig> steps = sparkRunJobFlowRequest.getSteps();
|
|
||||||
steps.add(step);
|
|
||||||
RunJobFlowRequest jobFlowRequest = sparkRunJobFlowRequest
|
|
||||||
.withSteps(steps)
|
|
||||||
.withInstances(sparkJobFlowInstancesConfig.withKeepJobFlowAliveWhenNoSteps(false));
|
|
||||||
|
|
||||||
RunJobFlowResult res = emr.runJobFlow(jobFlowRequest);
|
|
||||||
log.info("Your new cluster's id is %s.", res.getJobFlowId());
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Submit a Spark Job with a specified main class
|
|
||||||
*/
|
|
||||||
public void sparkSubmitJobWithMain(String[] args, String mainClass, File uberJar) throws Exception {
|
|
||||||
AmazonElasticMapReduce emr = sparkEmrClientBuilder.build();
|
|
||||||
submitJob(emr, mainClass, Arrays.asList(args), sparkSubmitConfs, uberJar);
|
|
||||||
}
|
|
||||||
|
|
||||||
private void checkStatus(AmazonElasticMapReduce emr, String clusterId) throws InterruptedException {
|
|
||||||
log.info(".");
|
|
||||||
com.amazonaws.services.elasticmapreduce.model.Cluster dcr =
|
|
||||||
emr.describeCluster((new DescribeClusterRequest()).withClusterId(clusterId)).getCluster();
|
|
||||||
String state = dcr.getStatus().getState();
|
|
||||||
long timeOutTime = System.currentTimeMillis() + ((long) sparkTimeoutDurationMinutes * 60 * 1000);
|
|
||||||
|
|
||||||
Boolean activated = Arrays.asList(activeClusterStates).contains(ClusterState.fromValue(state));
|
|
||||||
Boolean timedOut = System.currentTimeMillis() > timeOutTime;
|
|
||||||
if (activated && timedOut) {
|
|
||||||
emr.terminateJobFlows(
|
|
||||||
new TerminateJobFlowsRequest().withJobFlowIds(clusterId)
|
|
||||||
);
|
|
||||||
log.error("Timeout. Cluster terminated.");
|
|
||||||
} else if (!activated) {
|
|
||||||
Boolean hasAbnormalStep = false;
|
|
||||||
StepSummary stepS = null;
|
|
||||||
List<StepSummary> steps = emr.listSteps(new ListStepsRequest().withClusterId(clusterId)).getSteps();
|
|
||||||
for (StepSummary step : steps) {
|
|
||||||
if (step.getStatus().getState() != StepState.COMPLETED.toString()) {
|
|
||||||
hasAbnormalStep = true;
|
|
||||||
stepS = step;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (hasAbnormalStep && stepS != null)
|
|
||||||
log.error(String.format("Cluster %s terminated with an abnormal step, name %s, id %s", clusterId, stepS.getName(), stepS.getId()));
|
|
||||||
else
|
|
||||||
log.info("Cluster %s terminated without error.", clusterId);
|
|
||||||
} else {
|
|
||||||
Thread.sleep(5000);
|
|
||||||
checkStatus(emr, clusterId);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Monitor the cluster and terminates when it times out
|
|
||||||
*/
|
|
||||||
public void sparkMonitor() throws InterruptedException {
|
|
||||||
AmazonElasticMapReduce emr = sparkEmrClientBuilder.build();
|
|
||||||
Optional<ClusterSummary> optCsr = findClusterWithName(emr, sparkClusterName);
|
|
||||||
if (!optCsr.isPresent()) {
|
|
||||||
log.error(String.format("The cluster with name %s does not exist.", sparkClusterName));
|
|
||||||
} else {
|
|
||||||
ClusterSummary csr = optCsr.get();
|
|
||||||
log.info(String.format("found cluster with id %s, starting monitoring", csr.getId()));
|
|
||||||
checkStatus(emr, csr.getId());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Data
|
|
||||||
public static class Builder {
|
|
||||||
|
|
||||||
protected String sparkClusterName = RandomStringUtils.randomAlphanumeric(12);
|
|
||||||
protected String sparkAwsRegion = "us-east-1";
|
|
||||||
protected String sparkEmrRelease = "emr-5.9.0";
|
|
||||||
protected String sparkEmrServiceRole = "EMR_DefaultRole";
|
|
||||||
protected List<EmrConfig> sparkEmrConfigs = Collections.emptyList();
|
|
||||||
protected String sparkSubNetid = null;
|
|
||||||
protected List<String> sparkSecurityGroupIds = Collections.emptyList();
|
|
||||||
protected int sparkInstanceCount = 1;
|
|
||||||
protected String sparkInstanceType = "m3.xlarge";
|
|
||||||
protected Optional<Float> sparkInstanceBidPrice = Optional.empty();
|
|
||||||
protected String sparkInstanceRole = "EMR_EC2_DefaultRole";
|
|
||||||
protected String sparkS3JarFolder = "changeme";
|
|
||||||
protected int sparkTimeoutDurationMinutes = 90;
|
|
||||||
|
|
||||||
protected AmazonElasticMapReduceClientBuilder sparkEmrClientBuilder;
|
|
||||||
protected AmazonS3ClientBuilder sparkS3ClientBuilder;
|
|
||||||
protected JobFlowInstancesConfig sparkJobFlowInstancesConfig;
|
|
||||||
protected RunJobFlowRequest sparkRunJobFlowRequest;
|
|
||||||
|
|
||||||
// This should allow the user to decorate the put call to add metadata to the jar put command, such as security groups,
|
|
||||||
protected Function<PutObjectRequest, PutObjectRequest> sparkS3PutObjectDecorator = new Function<PutObjectRequest, PutObjectRequest>() {
|
|
||||||
@Override
|
|
||||||
public PutObjectRequest apply(PutObjectRequest putObjectRequest) {
|
|
||||||
return putObjectRequest;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
protected Map<String, String> sparkSubmitConfs;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Defines the EMR cluster's name
|
|
||||||
*
|
|
||||||
* @param clusterName the EMR cluster's name
|
|
||||||
* @return an EMR cluster builder
|
|
||||||
*/
|
|
||||||
public Builder clusterName(String clusterName) {
|
|
||||||
this.sparkClusterName = clusterName;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Defines the EMR cluster's region
|
|
||||||
* See https://docs.aws.amazon.com/general/latest/gr/rande.html
|
|
||||||
*
|
|
||||||
* @param region the EMR cluster's region
|
|
||||||
* @return an EMR cluster builder
|
|
||||||
*/
|
|
||||||
public Builder awsRegion(String region) {
|
|
||||||
this.sparkAwsRegion = region;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Defines the EMR release version to be used in this cluster
|
|
||||||
* uses a release label
|
|
||||||
* See https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-4.2.0/emr-release-differences.html#emr-release-label
|
|
||||||
*
|
|
||||||
* @param releaseLabel the EMR release label
|
|
||||||
* @return an EM cluster Builder
|
|
||||||
*/
|
|
||||||
public Builder emrRelease(String releaseLabel) {
|
|
||||||
this.sparkEmrRelease = releaseLabel;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Defines the IAM role to be assumed by the EMR service
|
|
||||||
* <p>
|
|
||||||
* https://docs.aws.amazon.com/IAM/latest/UserGuide/id_roles_create_for-service.html
|
|
||||||
*
|
|
||||||
* @param serviceRole the service role
|
|
||||||
* @return an EM cluster Builder
|
|
||||||
*/
|
|
||||||
public Builder emrServiceRole(String serviceRole) {
|
|
||||||
this.sparkEmrServiceRole = serviceRole;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A list of configuration parameters to apply to EMR instances.
|
|
||||||
*
|
|
||||||
* @param configs the EMR configurations to apply to this cluster
|
|
||||||
* @return an EMR cluster builder
|
|
||||||
*/
|
|
||||||
public Builder emrConfigs(List<EmrConfig> configs) {
|
|
||||||
this.sparkEmrConfigs = configs;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The id of the EC2 subnet to be used for this Spark EMR service
|
|
||||||
* see https://docs.aws.amazon.com/AmazonVPC/latest/UserGuide/VPC_Subnets.html
|
|
||||||
*
|
|
||||||
* @param id the subnet ID
|
|
||||||
* @return an EMR cluster builder
|
|
||||||
*/
|
|
||||||
public Builder subnetId(String id) {
|
|
||||||
this.sparkSubNetid = id;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The id of additional security groups this deployment should adopt for both master and slaves
|
|
||||||
*
|
|
||||||
* @param securityGroups
|
|
||||||
* @return an EMR cluster builder
|
|
||||||
*/
|
|
||||||
public Builder securityGroupIDs(List<String> securityGroups) {
|
|
||||||
this.sparkSecurityGroupIds = securityGroups;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The number of instances this deployment should comprise of
|
|
||||||
*
|
|
||||||
* @param count the number of instances for this cluster
|
|
||||||
* @rturn an EMR cluster buidler
|
|
||||||
*/
|
|
||||||
public Builder instanceCount(int count) {
|
|
||||||
this.sparkInstanceCount = count;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The type of instance this cluster should comprise of
|
|
||||||
* See https://aws.amazon.com/ec2/instance-types/
|
|
||||||
*
|
|
||||||
* @param instanceType the type of instance for this cluster
|
|
||||||
* @return an EMR cluster builder
|
|
||||||
*/
|
|
||||||
public Builder instanceType(String instanceType) {
|
|
||||||
this.sparkInstanceType = instanceType;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The optional bid value for this cluster's spot instances
|
|
||||||
* see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/how-spot-instances-work.html
|
|
||||||
* Uses the on-demand market if empty.
|
|
||||||
*
|
|
||||||
* @param optBid the Optional bid price for this cluster's instnces
|
|
||||||
* @return an EMR cluster Builder
|
|
||||||
*/
|
|
||||||
public Builder instanceBidPrice(Optional<Float> optBid) {
|
|
||||||
this.sparkInstanceBidPrice = optBid;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The EC2 instance role that this cluster's instances should assume
|
|
||||||
* see https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/iam-roles-for-amazon-ec2.html
|
|
||||||
*
|
|
||||||
* @param role the intended instance role
|
|
||||||
* @return an EMR cluster builder
|
|
||||||
*/
|
|
||||||
public Builder instanceRole(String role) {
|
|
||||||
this.sparkInstanceRole = role;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* the S3 folder in which to find the application jar
|
|
||||||
*
|
|
||||||
* @param jarfolder the S3 folder in which to find a jar
|
|
||||||
* @return an EMR cluster builder
|
|
||||||
*/
|
|
||||||
public Builder s3JarFolder(String jarfolder) {
|
|
||||||
this.sparkS3JarFolder = jarfolder;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The timeout duration for this Spark EMR cluster, in minutes
|
|
||||||
*
|
|
||||||
* @param timeoutMinutes
|
|
||||||
* @return an EMR cluster builder
|
|
||||||
*/
|
|
||||||
public Builder sparkTimeOutDurationMinutes(int timeoutMinutes) {
|
|
||||||
this.sparkTimeoutDurationMinutes = timeoutMinutes;
|
|
||||||
return this;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Creates an EMR Spark cluster deployment
|
|
||||||
*
|
|
||||||
* @return a SparkEMRClient
|
|
||||||
*/
|
|
||||||
public SparkEMRClient build() {
|
|
||||||
this.sparkEmrClientBuilder = AmazonElasticMapReduceClientBuilder.standard().withRegion(sparkAwsRegion);
|
|
||||||
this.sparkS3ClientBuilder = AmazonS3ClientBuilder.standard().withRegion(sparkAwsRegion);
|
|
||||||
// note this will be kept alive without steps, an arbitrary choice to avoid rapid test-teardown-restart cycles
|
|
||||||
this.sparkJobFlowInstancesConfig = (new JobFlowInstancesConfig()).withKeepJobFlowAliveWhenNoSteps(true);
|
|
||||||
if (this.sparkSubNetid != null)
|
|
||||||
this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withEc2SubnetId(this.sparkSubNetid);
|
|
||||||
if (!this.sparkSecurityGroupIds.isEmpty()) {
|
|
||||||
this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withAdditionalMasterSecurityGroups(this.sparkSecurityGroupIds);
|
|
||||||
this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withAdditionalSlaveSecurityGroups(this.sparkSecurityGroupIds);
|
|
||||||
}
|
|
||||||
|
|
||||||
InstanceGroupConfig masterConfig =
|
|
||||||
(new InstanceGroupConfig()).withInstanceCount(1).withInstanceRole(InstanceRoleType.MASTER).withInstanceType(sparkInstanceType);
|
|
||||||
if (sparkInstanceBidPrice.isPresent()) {
|
|
||||||
masterConfig = masterConfig.withMarket(MarketType.SPOT).withBidPrice(sparkInstanceBidPrice.get().toString());
|
|
||||||
} else {
|
|
||||||
masterConfig = masterConfig.withMarket(MarketType.ON_DEMAND);
|
|
||||||
}
|
|
||||||
|
|
||||||
int slaveCount = sparkInstanceCount - 1;
|
|
||||||
InstanceGroupConfig slaveConfig =
|
|
||||||
(new InstanceGroupConfig()).withInstanceCount(slaveCount).withInstanceRole(InstanceRoleType.CORE).withInstanceRole(sparkInstanceType);
|
|
||||||
if (sparkInstanceBidPrice.isPresent()) {
|
|
||||||
slaveConfig = slaveConfig.withMarket(MarketType.SPOT).withBidPrice(sparkInstanceBidPrice.get().toString());
|
|
||||||
} else {
|
|
||||||
slaveConfig = slaveConfig.withMarket(MarketType.ON_DEMAND);
|
|
||||||
}
|
|
||||||
if (slaveCount > 0) {
|
|
||||||
this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withInstanceGroups(Arrays.asList(masterConfig, slaveConfig));
|
|
||||||
} else {
|
|
||||||
this.sparkJobFlowInstancesConfig = this.sparkJobFlowInstancesConfig.withInstanceGroups(slaveConfig);
|
|
||||||
}
|
|
||||||
|
|
||||||
this.sparkRunJobFlowRequest = new RunJobFlowRequest();
|
|
||||||
if (!sparkEmrConfigs.isEmpty()) {
|
|
||||||
List<Configuration> emrConfigs = new ArrayList<>();
|
|
||||||
for (EmrConfig config : sparkEmrConfigs) {
|
|
||||||
emrConfigs.add(config.toAwsConfig());
|
|
||||||
}
|
|
||||||
this.sparkRunJobFlowRequest = this.sparkRunJobFlowRequest.withConfigurations(emrConfigs);
|
|
||||||
}
|
|
||||||
this.sparkRunJobFlowRequest =
|
|
||||||
this.sparkRunJobFlowRequest.withName(sparkClusterName).withApplications((new Application()).withName("Spark"))
|
|
||||||
.withReleaseLabel(sparkEmrRelease)
|
|
||||||
.withServiceRole(sparkEmrServiceRole)
|
|
||||||
.withJobFlowRole(sparkInstanceRole)
|
|
||||||
.withInstances(this.sparkJobFlowInstancesConfig);
|
|
||||||
|
|
||||||
return new SparkEMRClient(
|
|
||||||
this.sparkClusterName,
|
|
||||||
this.sparkAwsRegion,
|
|
||||||
this.sparkEmrRelease,
|
|
||||||
this.sparkEmrServiceRole,
|
|
||||||
this.sparkEmrConfigs,
|
|
||||||
this.sparkSubNetid,
|
|
||||||
this.sparkSecurityGroupIds,
|
|
||||||
this.sparkInstanceCount,
|
|
||||||
this.sparkInstanceType,
|
|
||||||
this.sparkInstanceBidPrice,
|
|
||||||
this.sparkInstanceRole,
|
|
||||||
this.sparkS3JarFolder,
|
|
||||||
this.sparkTimeoutDurationMinutes,
|
|
||||||
this.sparkEmrClientBuilder,
|
|
||||||
this.sparkS3ClientBuilder,
|
|
||||||
this.sparkJobFlowInstancesConfig,
|
|
||||||
this.sparkRunJobFlowRequest,
|
|
||||||
this.sparkS3PutObjectDecorator,
|
|
||||||
this.sparkSubmitConfs
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,117 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.aws.s3;
|
|
||||||
|
|
||||||
import com.amazonaws.auth.AWSCredentials;
|
|
||||||
import com.amazonaws.auth.BasicAWSCredentials;
|
|
||||||
import com.amazonaws.auth.PropertiesCredentials;
|
|
||||||
import com.amazonaws.services.ec2.AmazonEC2;
|
|
||||||
import com.amazonaws.services.ec2.AmazonEC2Client;
|
|
||||||
import com.amazonaws.services.s3.AmazonS3;
|
|
||||||
import com.amazonaws.services.s3.AmazonS3Client;
|
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.InputStream;
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* The S3 Credentials works via discovering the credentials
|
|
||||||
* from the system properties (passed in via -D or System wide)
|
|
||||||
* If you invoke the JVM with -Dorg.deeplearning4j.aws.accessKey=YOUR_ACCESS_KEY
|
|
||||||
* and -Dorg.deeplearning4j.aws.accessSecret=YOUR_SECRET_KEY
|
|
||||||
* this will pick up the credentials from there, otherwise it will also attempt to look in
|
|
||||||
* the system environment for the following variables:
|
|
||||||
*
|
|
||||||
*
|
|
||||||
* AWS_ACCESS_KEY_ID
|
|
||||||
* AWS_SECRET_ACCESS_KEY
|
|
||||||
* @author Adam Gibson
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public abstract class BaseS3 {
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
protected static final long serialVersionUID = -2280107690193651289L;
|
|
||||||
protected String accessKey;
|
|
||||||
protected String secretKey;
|
|
||||||
protected AWSCredentials creds;
|
|
||||||
public final static String ACCESS_KEY = "org.deeplearning4j.aws.accessKey";
|
|
||||||
public final static String ACCESS_SECRET = "org.deeplearning4j.aws.accessSecret";
|
|
||||||
public final static String AWS_ACCESS_KEY = "AWS_ACCESS_KEY"; //"AWS_ACCESS_KEY_ID";
|
|
||||||
public final static String AWS_SECRET_KEY = "AWS_SECRET_KEY"; //"AWS_SECRET_ACCESS_KEY";
|
|
||||||
|
|
||||||
|
|
||||||
protected void findCreds() {
|
|
||||||
if (System.getProperty(ACCESS_KEY) != null && System.getProperty(ACCESS_SECRET) != null) {
|
|
||||||
accessKey = System.getProperty(ACCESS_KEY);
|
|
||||||
secretKey = System.getProperty(ACCESS_SECRET);
|
|
||||||
}
|
|
||||||
|
|
||||||
else if (System.getenv(AWS_ACCESS_KEY) != null && System.getenv(AWS_SECRET_KEY) != null) {
|
|
||||||
accessKey = System.getenv(AWS_ACCESS_KEY);
|
|
||||||
secretKey = System.getenv(AWS_SECRET_KEY);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public BaseS3() {
|
|
||||||
findCreds();
|
|
||||||
if (accessKey != null && secretKey != null)
|
|
||||||
creds = new BasicAWSCredentials(accessKey, secretKey);
|
|
||||||
if (creds == null)
|
|
||||||
throw new IllegalStateException("Unable to find ec2 credentials");
|
|
||||||
}
|
|
||||||
|
|
||||||
public BaseS3(File file) throws Exception {
|
|
||||||
if (accessKey != null && secretKey != null)
|
|
||||||
creds = new BasicAWSCredentials(accessKey, secretKey);
|
|
||||||
else
|
|
||||||
creds = new PropertiesCredentials(file);
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public BaseS3(InputStream is) throws Exception {
|
|
||||||
if (accessKey != null && secretKey != null)
|
|
||||||
creds = new BasicAWSCredentials(accessKey, secretKey);
|
|
||||||
else
|
|
||||||
creds = new PropertiesCredentials(is);
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public AWSCredentials getCreds() {
|
|
||||||
return creds;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void setCreds(AWSCredentials creds) {
|
|
||||||
this.creds = creds;
|
|
||||||
}
|
|
||||||
|
|
||||||
public AmazonS3 getClient() {
|
|
||||||
return new AmazonS3Client(creds);
|
|
||||||
}
|
|
||||||
|
|
||||||
public AmazonEC2 getEc2() {
|
|
||||||
|
|
||||||
return new AmazonEC2Client(creds);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,83 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.aws.s3.reader;
|
|
||||||
|
|
||||||
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
|
|
||||||
|
|
||||||
import java.io.InputStream;
|
|
||||||
import java.util.Iterator;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* baseline data applyTransformToDestination iterator for
|
|
||||||
* @author Adam Gibson
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public abstract class BaseS3DataSetIterator implements DataSetIterator {
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
private static final long serialVersionUID = 885205002006822431L;
|
|
||||||
private S3Downloader downloader;
|
|
||||||
private List<String> buckets;
|
|
||||||
private int currBucket;
|
|
||||||
private Iterator<InputStream> currIterator;
|
|
||||||
|
|
||||||
public BaseS3DataSetIterator() {
|
|
||||||
downloader = new S3Downloader();
|
|
||||||
buckets = downloader.buckets();
|
|
||||||
currBucket = 0;
|
|
||||||
currIterator = downloader.iterateBucket(buckets.get(currBucket));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public InputStream nextObject() {
|
|
||||||
if (currIterator.hasNext())
|
|
||||||
return currIterator.next();
|
|
||||||
else if (currBucket < buckets.size()) {
|
|
||||||
currBucket++;
|
|
||||||
currIterator = downloader.iterateBucket(buckets.get(currBucket));
|
|
||||||
return currIterator.next();
|
|
||||||
}
|
|
||||||
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean hasNext() {
|
|
||||||
return currBucket < buckets.size() && currIterator.hasNext();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public String currBucket() {
|
|
||||||
return buckets.get(currBucket);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public void nextBucket() {
|
|
||||||
currBucket++;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,93 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.aws.s3.reader;
|
|
||||||
|
|
||||||
import com.amazonaws.services.s3.model.ObjectListing;
|
|
||||||
import com.amazonaws.services.s3.model.S3ObjectSummary;
|
|
||||||
|
|
||||||
import java.io.InputStream;
|
|
||||||
import java.util.Iterator;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Iterator over individual S3 bucket
|
|
||||||
* @author Adam Gibson
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public class BucketIterator implements Iterator<InputStream> {
|
|
||||||
|
|
||||||
private S3Downloader s3;
|
|
||||||
private String bucket;
|
|
||||||
private ObjectListing currList;
|
|
||||||
private List<S3ObjectSummary> currObjects;
|
|
||||||
private int currObject;
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
public BucketIterator(String bucket) {
|
|
||||||
this(bucket, null);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
public BucketIterator(String bucket, S3Downloader s3) {
|
|
||||||
this.bucket = bucket;
|
|
||||||
|
|
||||||
if (s3 == null)
|
|
||||||
this.s3 = new S3Downloader();
|
|
||||||
else
|
|
||||||
this.s3 = s3;
|
|
||||||
currList = this.s3.listObjects(bucket);
|
|
||||||
currObjects = currList.getObjectSummaries();
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean hasNext() {
|
|
||||||
return currObject < currObjects.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public InputStream next() {
|
|
||||||
if (currObject < currObjects.size()) {
|
|
||||||
InputStream ret = s3.objectForKey(bucket, currObjects.get(currObject).getKey());
|
|
||||||
currObject++;
|
|
||||||
return ret;
|
|
||||||
} else if (currList.isTruncated()) {
|
|
||||||
currList = s3.nextList(currList);
|
|
||||||
currObjects = currList.getObjectSummaries();
|
|
||||||
currObject = 0;
|
|
||||||
|
|
||||||
InputStream ret = s3.objectForKey(bucket, currObjects.get(currObject).getKey());
|
|
||||||
|
|
||||||
currObject++;
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
throw new IllegalStateException("Indeterminate state");
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void remove() {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,38 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.aws.s3.reader;
|
|
||||||
|
|
||||||
import com.amazonaws.services.s3.AmazonS3;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* When paginating through a result applyTransformToDestination,
|
|
||||||
* allows the user to react to a bucket result being found
|
|
||||||
* @author Adam Gibson
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public interface BucketKeyListener {
|
|
||||||
|
|
||||||
/**
|
|
||||||
*
|
|
||||||
* @param s3 an s3 client
|
|
||||||
* @param bucket the bucket being iterated on
|
|
||||||
* @param key the current key
|
|
||||||
*/
|
|
||||||
void onKey(AmazonS3 s3, String bucket, String key);
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,178 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.aws.s3.reader;
|
|
||||||
|
|
||||||
import com.amazonaws.services.s3.AmazonS3;
|
|
||||||
import com.amazonaws.services.s3.model.*;
|
|
||||||
import com.amazonaws.services.s3.transfer.MultipleFileDownload;
|
|
||||||
import com.amazonaws.services.s3.transfer.TransferManager;
|
|
||||||
import org.apache.commons.io.IOUtils;
|
|
||||||
import org.deeplearning4j.aws.s3.BaseS3;
|
|
||||||
|
|
||||||
import java.io.*;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Iterator;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Downloads files from S3
|
|
||||||
* @author Adam Gibson
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public class S3Downloader extends BaseS3 {
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Return the keys for a bucket
|
|
||||||
* @param bucket the bucket to get the keys for
|
|
||||||
* @return the bucket's keys
|
|
||||||
*/
|
|
||||||
public List<String> keysForBucket(String bucket) {
|
|
||||||
AmazonS3 s3 = getClient();
|
|
||||||
List<String> ret = new ArrayList<>();
|
|
||||||
ListObjectsRequest listObjectsRequest = new ListObjectsRequest().withBucketName(bucket);
|
|
||||||
ObjectListing objectListing;
|
|
||||||
|
|
||||||
do {
|
|
||||||
objectListing = s3.listObjects(listObjectsRequest);
|
|
||||||
for (S3ObjectSummary objectSummary : objectListing.getObjectSummaries()) {
|
|
||||||
ret.add(objectSummary.getKey());
|
|
||||||
}
|
|
||||||
listObjectsRequest.setMarker(objectListing.getNextMarker());
|
|
||||||
} while (objectListing.isTruncated());
|
|
||||||
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns the list of buckets in s3
|
|
||||||
* @return the list of buckets
|
|
||||||
*/
|
|
||||||
public List<String> buckets() {
|
|
||||||
List<String> ret = new ArrayList<>();
|
|
||||||
AmazonS3 s3 = getClient();
|
|
||||||
List<Bucket> buckets = s3.listBuckets();
|
|
||||||
for (Bucket b : buckets)
|
|
||||||
ret.add(b.getName());
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Iterate over individual buckets.
|
|
||||||
* Returns input streams to each object.
|
|
||||||
* It is your responsibility to close the input streams
|
|
||||||
* @param bucket the bucket to iterate over
|
|
||||||
* @return an iterator over the objects in an s3 bucket
|
|
||||||
*/
|
|
||||||
public Iterator<InputStream> iterateBucket(String bucket) {
|
|
||||||
return new BucketIterator(bucket, this);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Iterator style one list at a time
|
|
||||||
* @param list the list to get the next batch for
|
|
||||||
* @return the next batch of objects or null if
|
|
||||||
* none are left
|
|
||||||
*/
|
|
||||||
public ObjectListing nextList(ObjectListing list) {
|
|
||||||
AmazonS3 s3 = getClient();
|
|
||||||
if (list.isTruncated())
|
|
||||||
return s3.listNextBatchOfObjects(list);
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Simple way of retrieving the listings for a bucket
|
|
||||||
* @param bucket the bucket to retrieve listings for
|
|
||||||
* @return the object listing for this bucket
|
|
||||||
*/
|
|
||||||
public ObjectListing listObjects(String bucket) {
|
|
||||||
AmazonS3 s3 = getClient();
|
|
||||||
ObjectListing list = s3.listObjects(bucket);
|
|
||||||
return list;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Paginates through a bucket's keys invoking the listener
|
|
||||||
* at each key
|
|
||||||
* @param bucket the bucket to iterate
|
|
||||||
* @param listener the listener
|
|
||||||
*/
|
|
||||||
public void paginate(String bucket, BucketKeyListener listener) {
|
|
||||||
AmazonS3 s3 = getClient();
|
|
||||||
ObjectListing list = s3.listObjects(bucket);
|
|
||||||
for (S3ObjectSummary summary : list.getObjectSummaries()) {
|
|
||||||
if (listener != null)
|
|
||||||
listener.onKey(s3, bucket, summary.getKey());
|
|
||||||
}
|
|
||||||
|
|
||||||
while (list.isTruncated()) {
|
|
||||||
list = s3.listNextBatchOfObjects(list);
|
|
||||||
for (S3ObjectSummary summary : list.getObjectSummaries()) {
|
|
||||||
if (listener != null)
|
|
||||||
listener.onKey(s3, bucket, summary.getKey());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Returns an input stream for the given bucket and key
|
|
||||||
* @param bucket the bucket to retrieve from
|
|
||||||
* @param key the key of the objec t
|
|
||||||
* @return an input stream to the object
|
|
||||||
*/
|
|
||||||
public InputStream objectForKey(String bucket, String key) {
|
|
||||||
AmazonS3 s3 = getClient();
|
|
||||||
S3Object obj = s3.getObject(bucket, key);
|
|
||||||
InputStream is = obj.getObjectContent();
|
|
||||||
return is;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
public void download(String bucket, String key, File to) throws IOException {
|
|
||||||
AmazonS3 s3 = getClient();
|
|
||||||
S3Object obj = s3.getObject(bucket, key);
|
|
||||||
InputStream is = obj.getObjectContent();
|
|
||||||
BufferedOutputStream bos = new BufferedOutputStream(new FileOutputStream(to));
|
|
||||||
IOUtils.copy(is, bos);
|
|
||||||
bos.close();
|
|
||||||
is.close();
|
|
||||||
obj.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
public void download(String bucket, String key, OutputStream to) throws IOException {
|
|
||||||
AmazonS3 s3 = getClient();
|
|
||||||
S3Object obj = s3.getObject(bucket, key);
|
|
||||||
InputStream is = obj.getObjectContent();
|
|
||||||
BufferedOutputStream bos = new BufferedOutputStream(to);
|
|
||||||
|
|
||||||
IOUtils.copy(is, bos);
|
|
||||||
bos.close();
|
|
||||||
is.close();
|
|
||||||
obj.close();
|
|
||||||
}
|
|
||||||
|
|
||||||
public MultipleFileDownload downloadFolder(String bucketName, String keyPrefix, File folderPath) {
|
|
||||||
TransferManager transfer = new TransferManager(getClient());
|
|
||||||
return transfer.downloadDirectory(bucketName, keyPrefix, folderPath);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,172 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.aws.s3.uploader;
|
|
||||||
|
|
||||||
import com.amazonaws.services.s3.AmazonS3;
|
|
||||||
import com.amazonaws.services.s3.AmazonS3Client;
|
|
||||||
import com.amazonaws.services.s3.model.*;
|
|
||||||
import com.amazonaws.services.s3.transfer.MultipleFileUpload;
|
|
||||||
import com.amazonaws.services.s3.transfer.TransferManager;
|
|
||||||
import org.deeplearning4j.aws.s3.BaseS3;
|
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Uploads files to S3
|
|
||||||
*
|
|
||||||
* @see {@link BaseS3}
|
|
||||||
* @author Adam Gibson
|
|
||||||
*
|
|
||||||
*/
|
|
||||||
public class S3Uploader extends BaseS3 {
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Multi part upload for big files
|
|
||||||
* @param file the file to upload
|
|
||||||
* @param bucketName the bucket name to upload
|
|
||||||
*/
|
|
||||||
public void multiPartUpload(File file, String bucketName) {
|
|
||||||
AmazonS3 client = new AmazonS3Client(creds);
|
|
||||||
bucketName = ensureValidBucketName(bucketName);
|
|
||||||
|
|
||||||
List<Bucket> buckets = client.listBuckets();
|
|
||||||
for (Bucket b : buckets)
|
|
||||||
if (b.getName().equals(bucketName)) {
|
|
||||||
doMultiPart(client, bucketName, file);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
//bucket didn't exist: create it
|
|
||||||
client.createBucket(bucketName);
|
|
||||||
doMultiPart(client, bucketName, file);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Upload the file to the bucket.
|
|
||||||
* Will create the bucket if it hasn't already been created
|
|
||||||
* @param file the file to upload
|
|
||||||
* @param bucketName the name of the bucket
|
|
||||||
*/
|
|
||||||
public void upload(File file, String bucketName) {
|
|
||||||
AmazonS3 client = new AmazonS3Client(creds);
|
|
||||||
bucketName = ensureValidBucketName(bucketName);
|
|
||||||
|
|
||||||
List<Bucket> buckets = client.listBuckets();
|
|
||||||
for (Bucket b : buckets)
|
|
||||||
if (b.getName().equals(bucketName)) {
|
|
||||||
client.putObject(bucketName, file.getName(), file);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
//bucket didn't exist: create it
|
|
||||||
client.createBucket(bucketName);
|
|
||||||
client.putObject(bucketName, file.getName(), file);
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
private void doMultiPart(AmazonS3 s3Client, String bucketName, File file) {
|
|
||||||
// Create a list of UploadPartResponse objects. You get one of these
|
|
||||||
// for each part upload.
|
|
||||||
List<PartETag> partETags = new ArrayList<>();
|
|
||||||
|
|
||||||
// Step 1: Initialize.
|
|
||||||
InitiateMultipartUploadRequest initRequest = new InitiateMultipartUploadRequest(bucketName, file.getName());
|
|
||||||
InitiateMultipartUploadResult initResponse = s3Client.initiateMultipartUpload(initRequest);
|
|
||||||
|
|
||||||
long contentLength = file.length();
|
|
||||||
long partSize = 5242880; // Set part size to 5 MB.
|
|
||||||
|
|
||||||
try {
|
|
||||||
// Step 2: Upload parts.
|
|
||||||
long filePosition = 0;
|
|
||||||
for (int i = 1; filePosition < contentLength; i++) {
|
|
||||||
// Last part can be less than 5 MB. Adjust part size.
|
|
||||||
partSize = Math.min(partSize, (contentLength - filePosition));
|
|
||||||
|
|
||||||
// Create request to upload a part.
|
|
||||||
UploadPartRequest uploadRequest = new UploadPartRequest().withBucketName(bucketName)
|
|
||||||
.withKey(file.getName()).withUploadId(initResponse.getUploadId()).withPartNumber(i)
|
|
||||||
.withFileOffset(filePosition).withFile(file).withPartSize(partSize);
|
|
||||||
|
|
||||||
// Upload part and add response to our list.
|
|
||||||
partETags.add(s3Client.uploadPart(uploadRequest).getPartETag());
|
|
||||||
|
|
||||||
filePosition += partSize;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Step 3: Complete.
|
|
||||||
CompleteMultipartUploadRequest compRequest = new CompleteMultipartUploadRequest(bucketName, file.getName(),
|
|
||||||
initResponse.getUploadId(), partETags);
|
|
||||||
|
|
||||||
s3Client.completeMultipartUpload(compRequest);
|
|
||||||
} catch (Exception e) {
|
|
||||||
s3Client.abortMultipartUpload(
|
|
||||||
new AbortMultipartUploadRequest(bucketName, file.getName(), initResponse.getUploadId()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private String ensureValidBucketName(String bucketName) {
|
|
||||||
String formatted = bucketName.replaceAll("\\s+", "_");
|
|
||||||
int length = bucketName.length();
|
|
||||||
if (length >= 62)
|
|
||||||
length = 62;
|
|
||||||
formatted = formatted.substring(0, length);
|
|
||||||
formatted = formatted.replace(".", "d");
|
|
||||||
formatted = formatted.toLowerCase();
|
|
||||||
if (formatted.endsWith("-"))
|
|
||||||
formatted = formatted.substring(0, length - 1);
|
|
||||||
|
|
||||||
return formatted;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void upload(File file, String name, String bucketName) {
|
|
||||||
AmazonS3 client = getClient();
|
|
||||||
bucketName = ensureValidBucketName(bucketName);
|
|
||||||
List<Bucket> buckets = client.listBuckets();
|
|
||||||
// ObjectMetadata med = new ObjectMetadata();
|
|
||||||
// med.setContentLength(fileLength);
|
|
||||||
for (Bucket b : buckets)
|
|
||||||
if (b.getName().equals(bucketName)) {
|
|
||||||
//client.putObject(bucketName, name, is, med);
|
|
||||||
client.putObject(new PutObjectRequest(bucketName, name, file));
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
//bucket didn't exist: createComplex it
|
|
||||||
client.createBucket(bucketName);
|
|
||||||
//client.putObject(bucketName, name, is, med);
|
|
||||||
client.putObject(new PutObjectRequest(bucketName, name, file));
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
public MultipleFileUpload uploadFolder(String bucketName, String keyPrefix, File folderPath,
|
|
||||||
boolean includeSubDir) {
|
|
||||||
TransferManager transfer = new TransferManager(getClient());
|
|
||||||
return transfer.uploadDirectory(bucketName, keyPrefix, folderPath, includeSubDir);
|
|
||||||
}
|
|
||||||
|
|
||||||
public MultipleFileUpload uploadFileList(String bucketName, File folderPath, List<File> fileList,
|
|
||||||
String keyPrefix) {
|
|
||||||
TransferManager transfer = new TransferManager(getClient());
|
|
||||||
return transfer.uploadFileList(bucketName, keyPrefix, folderPath, fileList);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,45 +0,0 @@
|
||||||
#!/bin/bash
|
|
||||||
|
|
||||||
################################################################################
|
|
||||||
# Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
#
|
|
||||||
# This program and the accompanying materials are made available under the
|
|
||||||
# terms of the Apache License, Version 2.0 which is available at
|
|
||||||
# https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
# License for the specific language governing permissions and limitations
|
|
||||||
# under the License.
|
|
||||||
#
|
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
|
||||||
################################################################################
|
|
||||||
|
|
||||||
sudo yum -y install blas java-1.7.0-openjdk-devel
|
|
||||||
|
|
||||||
if [ ! -f "/opt/dl4j" ];
|
|
||||||
then
|
|
||||||
sudo mkdir /opt/dl4j
|
|
||||||
sudo yum -y install git
|
|
||||||
git clone https://github.com/agibsonccc/java-deeplearning
|
|
||||||
|
|
||||||
wget http://www.trieuvan.com/apache/maven/maven-3/3.1.1/binaries/apache-maven-3.1.1-bin.tar.gz
|
|
||||||
sudo mv apache-maven-3.1.1-bin.tar.gz /opt
|
|
||||||
cd /opt && sudo tar xvf apache-maven-3.1.1-bin.tar.gz && sudo mv apache-maven-3.1.1 /opt/mvn
|
|
||||||
cd /home/ec2-user/java-deeplearning/ && /opt/mvn/bin/mvn -DskipTests clean install
|
|
||||||
echo "Printing distribution"
|
|
||||||
ls /home/ec2-user/java-deeplearning/deeplearning4j-distribution/target
|
|
||||||
echo "Before moving distribution"
|
|
||||||
sudo mv deeplearning4j-distribution/target/deeplearning4j-dist.tar.gz /opt
|
|
||||||
echo "Moving distribution to opt directory..."
|
|
||||||
echo "Moving in to opt directory"
|
|
||||||
cd /opt
|
|
||||||
|
|
||||||
sudo tar xzvf deeplearning4j-dist.tar.gz
|
|
||||||
#sudo mkdir /opt/dl4j
|
|
||||||
echo "Moving jars in to /opt/dl4j/..."
|
|
||||||
sudo mv *.jar /opt/dl4j
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,6 @@
|
||||||
<name>DeepLearning4j-scaleout-parent</name>
|
<name>DeepLearning4j-scaleout-parent</name>
|
||||||
|
|
||||||
<modules>
|
<modules>
|
||||||
<module>deeplearning4j-aws</module>
|
|
||||||
<module>spark</module>
|
<module>spark</module>
|
||||||
<module>deeplearning4j-scaleout-parallelwrapper</module>
|
<module>deeplearning4j-scaleout-parallelwrapper</module>
|
||||||
<module>deeplearning4j-scaleout-parallelwrapper-parameter-server</module>
|
<module>deeplearning4j-scaleout-parallelwrapper-parameter-server</module>
|
||||||
|
|
|
@ -1,58 +0,0 @@
|
||||||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
~
|
|
||||||
~ This program and the accompanying materials are made available under the
|
|
||||||
~ terms of the Apache License, Version 2.0 which is available at
|
|
||||||
~ https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
~
|
|
||||||
~ Unless required by applicable law or agreed to in writing, software
|
|
||||||
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
~ License for the specific language governing permissions and limitations
|
|
||||||
~ under the License.
|
|
||||||
~
|
|
||||||
~ SPDX-License-Identifier: Apache-2.0
|
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
|
||||||
|
|
||||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
|
||||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
|
||||||
<parent>
|
|
||||||
<artifactId>deeplearning4j-parent</artifactId>
|
|
||||||
<groupId>org.deeplearning4j</groupId>
|
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
|
||||||
</parent>
|
|
||||||
<modelVersion>4.0.0</modelVersion>
|
|
||||||
|
|
||||||
<artifactId>deeplearning4j-util</artifactId>
|
|
||||||
<packaging>jar</packaging>
|
|
||||||
|
|
||||||
<name>deeplearning4j-util</name>
|
|
||||||
<url>http://maven.apache.org</url>
|
|
||||||
|
|
||||||
<properties>
|
|
||||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
|
||||||
</properties>
|
|
||||||
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-api</artifactId>
|
|
||||||
<version>${nd4j.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-common</artifactId>
|
|
||||||
<version>${nd4j.version}</version>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
|
|
||||||
|
|
||||||
<profiles>
|
|
||||||
<profile>
|
|
||||||
<id>test-nd4j-native</id>
|
|
||||||
</profile>
|
|
||||||
<profile>
|
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
|
||||||
</profile>
|
|
||||||
</profiles>
|
|
||||||
</project>
|
|
|
@ -1,193 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.util;
|
|
||||||
|
|
||||||
import org.apache.commons.io.FileUtils;
|
|
||||||
import org.nd4j.linalg.util.SerializationUtils;
|
|
||||||
|
|
||||||
import java.io.File;
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.Serializable;
|
|
||||||
import java.util.Collection;
|
|
||||||
import java.util.Iterator;
|
|
||||||
import java.util.Queue;
|
|
||||||
import java.util.UUID;
|
|
||||||
import java.util.concurrent.ConcurrentLinkedDeque;
|
|
||||||
import java.util.concurrent.Executors;
|
|
||||||
import java.util.concurrent.atomic.AtomicBoolean;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Naive disk based queue for storing items on disk.
|
|
||||||
* Only meant for poll and adding items.
|
|
||||||
* @author Adam Gibson
|
|
||||||
*/
|
|
||||||
public class DiskBasedQueue<E> implements Queue<E>, Serializable {
|
|
||||||
|
|
||||||
private File dir;
|
|
||||||
private Queue<String> paths = new ConcurrentLinkedDeque<>();
|
|
||||||
private AtomicBoolean running = new AtomicBoolean(true);
|
|
||||||
private Queue<E> save = new ConcurrentLinkedDeque<>();
|
|
||||||
|
|
||||||
public DiskBasedQueue() {
|
|
||||||
this(".queue");
|
|
||||||
}
|
|
||||||
|
|
||||||
public DiskBasedQueue(String path) {
|
|
||||||
this(new File(path));
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
public DiskBasedQueue(File dir) {
|
|
||||||
this.dir = dir;
|
|
||||||
if (!dir.exists() && dir.isDirectory()) {
|
|
||||||
throw new IllegalArgumentException("Illegal queue: must be a directory");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!dir.exists())
|
|
||||||
dir.mkdirs();
|
|
||||||
if (dir.listFiles() != null && dir.listFiles().length > 1)
|
|
||||||
try {
|
|
||||||
FileUtils.deleteDirectory(dir);
|
|
||||||
} catch (IOException e) {
|
|
||||||
e.printStackTrace();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
dir.mkdir();
|
|
||||||
|
|
||||||
Thread t = Executors.defaultThreadFactory().newThread(new Runnable() {
|
|
||||||
@Override
|
|
||||||
public void run() {
|
|
||||||
while (running.get()) {
|
|
||||||
while (!save.isEmpty())
|
|
||||||
addAndSave(save.poll());
|
|
||||||
|
|
||||||
ThreadUtils.uncheckedSleep(1000);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
t.setName("DiskBasedQueueSaver");
|
|
||||||
t.setDaemon(true);
|
|
||||||
t.start();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public int size() {
|
|
||||||
return paths.size();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean isEmpty() {
|
|
||||||
return paths.isEmpty();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean contains(Object o) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Iterator<E> iterator() {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public Object[] toArray() {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public <T> T[] toArray(T[] a) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean add(E e) {
|
|
||||||
save.add(e);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean remove(Object o) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean containsAll(Collection<?> c) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean addAll(Collection<? extends E> c) {
|
|
||||||
for (E e : c)
|
|
||||||
addAndSave(e);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean removeAll(Collection<?> c) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean retainAll(Collection<?> c) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void clear() {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean offer(E e) {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public E remove() {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public E poll() {
|
|
||||||
String path = paths.poll();
|
|
||||||
E ret = SerializationUtils.readObject(new File(path));
|
|
||||||
File item = new File(path);
|
|
||||||
item.delete();
|
|
||||||
return ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public E element() {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public E peek() {
|
|
||||||
throw new UnsupportedOperationException();
|
|
||||||
}
|
|
||||||
|
|
||||||
private void addAndSave(E e) {
|
|
||||||
File path = new File(dir, UUID.randomUUID().toString());
|
|
||||||
SerializationUtils.saveObject(e, path);
|
|
||||||
paths.add(path.getAbsolutePath());
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,132 +0,0 @@
|
||||||
/*******************************************************************************
|
|
||||||
* Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
*
|
|
||||||
* This program and the accompanying materials are made available under the
|
|
||||||
* terms of the Apache License, Version 2.0 which is available at
|
|
||||||
* https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
*
|
|
||||||
* Unless required by applicable law or agreed to in writing, software
|
|
||||||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
* License for the specific language governing permissions and limitations
|
|
||||||
* under the License.
|
|
||||||
*
|
|
||||||
* SPDX-License-Identifier: Apache-2.0
|
|
||||||
******************************************************************************/
|
|
||||||
|
|
||||||
package org.deeplearning4j.util;
|
|
||||||
|
|
||||||
|
|
||||||
import java.lang.reflect.Constructor;
|
|
||||||
import java.lang.reflect.Field;
|
|
||||||
import java.lang.reflect.Modifier;
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.List;
|
|
||||||
import java.util.Properties;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @author Adam Gibson
|
|
||||||
*/
|
|
||||||
public class Dl4jReflection {
|
|
||||||
private Dl4jReflection() {}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Gets the empty constructor from a class
|
|
||||||
* @param clazz the class to get the constructor from
|
|
||||||
* @return the empty constructor for the class
|
|
||||||
*/
|
|
||||||
public static Constructor<?> getEmptyConstructor(Class<?> clazz) {
|
|
||||||
Constructor<?> c = clazz.getDeclaredConstructors()[0];
|
|
||||||
for (int i = 0; i < clazz.getDeclaredConstructors().length; i++) {
|
|
||||||
if (clazz.getDeclaredConstructors()[i].getParameterTypes().length < 1) {
|
|
||||||
c = clazz.getDeclaredConstructors()[i];
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return c;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
public static Field[] getAllFields(Class<?> clazz) {
|
|
||||||
// Keep backing up the inheritance hierarchy.
|
|
||||||
Class<?> targetClass = clazz;
|
|
||||||
List<Field> fields = new ArrayList<>();
|
|
||||||
|
|
||||||
do {
|
|
||||||
fields.addAll(Arrays.asList(targetClass.getDeclaredFields()));
|
|
||||||
targetClass = targetClass.getSuperclass();
|
|
||||||
} while (targetClass != null && targetClass != Object.class);
|
|
||||||
|
|
||||||
return fields.toArray(new Field[fields.size()]);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Sets the properties of the given object
|
|
||||||
* @param obj the object o set
|
|
||||||
* @param props the properties to set
|
|
||||||
*/
|
|
||||||
public static void setProperties(Object obj, Properties props) throws Exception {
|
|
||||||
for (Field field : obj.getClass().getDeclaredFields()) {
|
|
||||||
field.setAccessible(true);
|
|
||||||
if (props.containsKey(field.getName())) {
|
|
||||||
set(field, obj, props.getProperty(field.getName()));
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/* sets a field with a fairly basic strategy */
|
|
||||||
private static void set(Field field, Object obj, String value) throws Exception {
|
|
||||||
Class<?> clazz = field.getType();
|
|
||||||
field.setAccessible(true);
|
|
||||||
if (clazz.equals(Double.class) || clazz.equals(double.class)) {
|
|
||||||
double val = Double.valueOf(value);
|
|
||||||
field.set(obj, val);
|
|
||||||
} else if (clazz.equals(String.class)) {
|
|
||||||
field.set(obj, value);
|
|
||||||
} else if (clazz.equals(Integer.class) || clazz.equals(int.class)) {
|
|
||||||
int val = Integer.parseInt(value);
|
|
||||||
field.set(obj, val);
|
|
||||||
} else if (clazz.equals(Float.class) || clazz.equals(float.class)) {
|
|
||||||
float f = Float.parseFloat(value);
|
|
||||||
field.set(obj, f);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Get fields as properties
|
|
||||||
* @param obj the object to get fields for
|
|
||||||
* @param clazzes the classes to use for reflection and properties.
|
|
||||||
* T
|
|
||||||
* @return the fields as properties
|
|
||||||
*/
|
|
||||||
public static Properties getFieldsAsProperties(Object obj, Class<?>[] clazzes) throws Exception {
|
|
||||||
Properties props = new Properties();
|
|
||||||
for (Field field : obj.getClass().getDeclaredFields()) {
|
|
||||||
if (Modifier.isStatic(field.getModifiers()))
|
|
||||||
continue;
|
|
||||||
field.setAccessible(true);
|
|
||||||
Class<?> type = field.getType();
|
|
||||||
if (clazzes == null || contains(type, clazzes)) {
|
|
||||||
Object val = field.get(obj);
|
|
||||||
if (val != null)
|
|
||||||
props.put(field.getName(), val.toString());
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return props;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
private static boolean contains(Class<?> test, Class<?>[] arr) {
|
|
||||||
for (Class<?> c : arr)
|
|
||||||
if (c.equals(test))
|
|
||||||
return true;
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
|
@ -1,135 +0,0 @@
|
||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
|
|
||||||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
~
|
|
||||||
~ This program and the accompanying materials are made available under the
|
|
||||||
~ terms of the Apache License, Version 2.0 which is available at
|
|
||||||
~ https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
~
|
|
||||||
~ Unless required by applicable law or agreed to in writing, software
|
|
||||||
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
~ License for the specific language governing permissions and limitations
|
|
||||||
~ under the License.
|
|
||||||
~
|
|
||||||
~ SPDX-License-Identifier: Apache-2.0
|
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
|
||||||
|
|
||||||
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
|
|
||||||
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
|
|
||||||
<parent>
|
|
||||||
<artifactId>deeplearning4j-parent</artifactId>
|
|
||||||
<groupId>org.deeplearning4j</groupId>
|
|
||||||
<version>1.0.0-SNAPSHOT</version>
|
|
||||||
</parent>
|
|
||||||
<modelVersion>4.0.0</modelVersion>
|
|
||||||
|
|
||||||
<artifactId>dl4j-perf</artifactId>
|
|
||||||
|
|
||||||
<name>dl4j-perf</name>
|
|
||||||
|
|
||||||
<properties>
|
|
||||||
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
|
|
||||||
<maven.compiler.source>1.7</maven.compiler.source>
|
|
||||||
<maven.compiler.target>1.7</maven.compiler.target>
|
|
||||||
</properties>
|
|
||||||
|
|
||||||
<dependencies>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.slf4j</groupId>
|
|
||||||
<artifactId>slf4j-api</artifactId>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>com.github.oshi</groupId>
|
|
||||||
<artifactId>oshi-json</artifactId>
|
|
||||||
<version>${oshi.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.deeplearning4j</groupId>
|
|
||||||
<artifactId>deeplearning4j-nn</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>com.github.oshi</groupId>
|
|
||||||
<artifactId>oshi-core</artifactId>
|
|
||||||
<version>${oshi.version}</version>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>junit</groupId>
|
|
||||||
<artifactId>junit</artifactId>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.projectlombok</groupId>
|
|
||||||
<artifactId>lombok</artifactId>
|
|
||||||
<version>${lombok.version}</version>
|
|
||||||
<scope>provided</scope>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>ch.qos.logback</groupId>
|
|
||||||
<artifactId>logback-classic</artifactId>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.deeplearning4j</groupId>
|
|
||||||
<artifactId>deeplearning4j-datasets</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.deeplearning4j</groupId>
|
|
||||||
<artifactId>deeplearning4j-common-tests</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
<scope>test</scope>
|
|
||||||
</dependency>
|
|
||||||
</dependencies>
|
|
||||||
|
|
||||||
<build>
|
|
||||||
<pluginManagement><!-- lock down plugins versions to avoid using Maven defaults (may be moved to parent pom) -->
|
|
||||||
<plugins>
|
|
||||||
<plugin>
|
|
||||||
<artifactId>maven-clean-plugin</artifactId>
|
|
||||||
<version>3.0.0</version>
|
|
||||||
</plugin>
|
|
||||||
<!-- see http://maven.apache.org/ref/current/maven-core/default-bindings.html#Plugin_bindings_for_jar_packaging -->
|
|
||||||
<plugin>
|
|
||||||
<artifactId>maven-resources-plugin</artifactId>
|
|
||||||
<version>3.0.2</version>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<artifactId>maven-compiler-plugin</artifactId>
|
|
||||||
<version>3.7.0</version>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<artifactId>maven-surefire-plugin</artifactId>
|
|
||||||
<version>2.20.1</version>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<artifactId>maven-jar-plugin</artifactId>
|
|
||||||
<version>3.0.2</version>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<artifactId>maven-install-plugin</artifactId>
|
|
||||||
<version>2.5.2</version>
|
|
||||||
</plugin>
|
|
||||||
<plugin>
|
|
||||||
<artifactId>maven-deploy-plugin</artifactId>
|
|
||||||
<version>2.8.2</version>
|
|
||||||
</plugin>
|
|
||||||
</plugins>
|
|
||||||
</pluginManagement>
|
|
||||||
</build>
|
|
||||||
|
|
||||||
<profiles>
|
|
||||||
<profile>
|
|
||||||
<id>test-nd4j-native</id>
|
|
||||||
</profile>
|
|
||||||
<profile>
|
|
||||||
<id>test-nd4j-cuda-10.2</id>
|
|
||||||
</profile>
|
|
||||||
</profiles>
|
|
||||||
</project>
|
|
|
@ -1,50 +0,0 @@
|
||||||
<!--~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
~ Copyright (c) 2015-2018 Skymind, Inc.
|
|
||||||
~
|
|
||||||
~ This program and the accompanying materials are made available under the
|
|
||||||
~ terms of the Apache License, Version 2.0 which is available at
|
|
||||||
~ https://www.apache.org/licenses/LICENSE-2.0.
|
|
||||||
~
|
|
||||||
~ Unless required by applicable law or agreed to in writing, software
|
|
||||||
~ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
|
||||||
~ WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
|
||||||
~ License for the specific language governing permissions and limitations
|
|
||||||
~ under the License.
|
|
||||||
~
|
|
||||||
~ SPDX-License-Identifier: Apache-2.0
|
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~-->
|
|
||||||
|
|
||||||
<configuration>
|
|
||||||
|
|
||||||
<appender name="FILE" class="ch.qos.logback.core.FileAppender">
|
|
||||||
<file>logs/application.log</file>
|
|
||||||
<encoder>
|
|
||||||
<pattern>%date - [%level] - from %logger in %thread
|
|
||||||
%n%message%n%xException%n</pattern>
|
|
||||||
</encoder>
|
|
||||||
</appender>
|
|
||||||
|
|
||||||
<appender name="STDOUT" class="ch.qos.logback.core.ConsoleAppender">
|
|
||||||
<encoder>
|
|
||||||
<pattern> %logger{15} - %message%n%xException{5}
|
|
||||||
</pattern>
|
|
||||||
</encoder>
|
|
||||||
</appender>
|
|
||||||
|
|
||||||
<logger name="org.apache.catalina.core" level="DEBUG" />
|
|
||||||
<logger name="org.springframework" level="DEBUG" />
|
|
||||||
<logger name="org.deeplearning4j" level="INFO" />
|
|
||||||
<logger name="org.datavec" level="INFO" />
|
|
||||||
<logger name="org.nd4j" level="INFO" />
|
|
||||||
<logger name="opennlp.uima.util" level="OFF" />
|
|
||||||
<logger name="org.apache.uima" level="OFF" />
|
|
||||||
<logger name="org.cleartk" level="OFF" />
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
<root level="ERROR">
|
|
||||||
<appender-ref ref="STDOUT" />
|
|
||||||
<appender-ref ref="FILE" />
|
|
||||||
</root>
|
|
||||||
|
|
||||||
</configuration>
|
|
|
@ -140,8 +140,6 @@
|
||||||
<module>deeplearning4j-nearestneighbors-parent</module>
|
<module>deeplearning4j-nearestneighbors-parent</module>
|
||||||
<module>deeplearning4j-data</module>
|
<module>deeplearning4j-data</module>
|
||||||
<module>deeplearning4j-manifold</module>
|
<module>deeplearning4j-manifold</module>
|
||||||
<module>deeplearning4j-util</module>
|
|
||||||
<module>dl4j-perf</module>
|
|
||||||
<module>dl4j-integration-tests</module>
|
<module>dl4j-integration-tests</module>
|
||||||
<module>deeplearning4j-common</module>
|
<module>deeplearning4j-common</module>
|
||||||
<module>deeplearning4j-remote</module>
|
<module>deeplearning4j-remote</module>
|
||||||
|
|
|
@ -39,6 +39,7 @@ do
|
||||||
done
|
done
|
||||||
|
|
||||||
CHIP="${CHIP:-cpu}"
|
CHIP="${CHIP:-cpu}"
|
||||||
|
export GTEST_OUTPUT="xml:../target/surefire-reports/TEST-${CHIP}-results.xml"
|
||||||
|
|
||||||
# On Mac, make sure it can find libraries for GCC
|
# On Mac, make sure it can find libraries for GCC
|
||||||
export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/lib/gcc/6/:/usr/local/lib/gcc/5/
|
export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/lib/gcc/6/:/usr/local/lib/gcc/5/
|
||||||
|
@ -47,8 +48,9 @@ export DYLD_LIBRARY_PATH=/usr/local/lib/gcc/8/:/usr/local/lib/gcc/7/:/usr/local/
|
||||||
if [ -n "$BUILD_PATH" ]; then
|
if [ -n "$BUILD_PATH" ]; then
|
||||||
if which cygpath; then
|
if which cygpath; then
|
||||||
BUILD_PATH=$(cygpath -p $BUILD_PATH)
|
BUILD_PATH=$(cygpath -p $BUILD_PATH)
|
||||||
|
export GTEST_OUTPUT="xml:'..\target\surefire-reports\TEST-${CHIP}-results.xml'"
|
||||||
fi
|
fi
|
||||||
export PATH="$PATH:$BUILD_PATH"
|
export PATH="$PATH:$BUILD_PATH"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests --gtest_output="xml:../target/surefire-reports/TEST-${CHIP}-results.xml"
|
../blasbuild/${CHIP}/tests_cpu/layers_tests/runtests
|
||||||
|
|
|
@ -217,16 +217,6 @@
|
||||||
<artifactId>commons-net</artifactId>
|
<artifactId>commons-net</artifactId>
|
||||||
<version>${commons-net.version}</version>
|
<version>${commons-net.version}</version>
|
||||||
</dependency>
|
</dependency>
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-buffer</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
|
||||||
<groupId>org.nd4j</groupId>
|
|
||||||
<artifactId>nd4j-context</artifactId>
|
|
||||||
<version>${project.version}</version>
|
|
||||||
</dependency>
|
|
||||||
<dependency>
|
<dependency>
|
||||||
<groupId>net.ericaro</groupId>
|
<groupId>net.ericaro</groupId>
|
||||||
<artifactId>neoitertools</artifactId>
|
<artifactId>neoitertools</artifactId>
|
||||||
|
@ -238,6 +228,16 @@
|
||||||
</exclusion>
|
</exclusion>
|
||||||
</exclusions>
|
</exclusions>
|
||||||
</dependency>
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.nd4j</groupId>
|
||||||
|
<artifactId>nd4j-common</artifactId>
|
||||||
|
<version>${project.version}</version>
|
||||||
|
</dependency>
|
||||||
|
<dependency>
|
||||||
|
<groupId>org.bytedeco</groupId>
|
||||||
|
<artifactId>javacpp</artifactId>
|
||||||
|
<version>${javacpp.version}</version>
|
||||||
|
</dependency>
|
||||||
</dependencies>
|
</dependencies>
|
||||||
|
|
||||||
<profiles>
|
<profiles>
|
||||||
|
|
|
@ -81,11 +81,6 @@ public class BroadcastMax extends BaseBroadcastOp {
|
||||||
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
|
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "Max";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
return null;
|
return null;
|
||||||
|
|
|
@ -81,11 +81,6 @@ public class BroadcastMin extends BaseBroadcastOp {
|
||||||
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
|
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "Min";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
return null;
|
return null;
|
||||||
|
|
|
@ -75,11 +75,6 @@ public class BroadcastMulOp extends BaseBroadcastOp {
|
||||||
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
|
throw new NoOpNameFoundException("No ONNX op name found for: " + getClass().getName());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName() {
|
|
||||||
return "Mul";
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
return null;
|
return null;
|
||||||
|
|
|
@ -75,11 +75,6 @@ public class BroadcastSubOp extends BaseBroadcastOp {
|
||||||
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public String tensorflowName(){
|
|
||||||
throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
public List<SDVariable> doDiff(List<SDVariable> f1) {
|
||||||
return null;
|
return null;
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue