/*******************************************************************************
 * Copyright (c) 2015-2018 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.parallelism;

import lombok.*;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.api.storage.StatsStorageRouter;
import org.deeplearning4j.api.storage.listener.RoutingIterationListener;
import org.nd4j.linalg.dataset.AsyncDataSetIterator;
import org.nd4j.linalg.dataset.AsyncMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.DummyBlockDataSetIterator;
import org.deeplearning4j.datasets.iterator.DummyBlockMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.callbacks.InterleavedDataSetCallback;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.api.Updater;
import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.updater.graph.ComputationGraphUpdater;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.deeplearning4j.optimize.listeners.SharedGradient;
import org.deeplearning4j.optimize.solvers.accumulation.EncodedGradientsAccumulator;
import org.deeplearning4j.optimize.solvers.accumulation.GradientsAccumulator;
import org.deeplearning4j.optimize.solvers.accumulation.Registerable;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ResidualPostProcessor;
import org.deeplearning4j.optimize.solvers.accumulation.encoding.ThresholdAlgorithm;
import org.deeplearning4j.parallelism.factory.DefaultTrainerContext;
import org.deeplearning4j.parallelism.factory.SymmetricTrainerContext;
import org.deeplearning4j.parallelism.factory.TrainerContext;
import org.deeplearning4j.parallelism.trainer.Trainer;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.DataSet;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.function.Supplier;

import java.util.*;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

/**
 * This is a simple data-parallel wrapper
 * suitable for multi-CPU/multi-GPU environments.
 *
 * PLEASE NOTE: This implementation is NOT NUMA-aware.
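 *
 * A minimal usage sketch (assuming a pre-configured {@code MultiLayerNetwork} named
 * {@code net} and a {@code DataSetIterator} named {@code trainData}; both names are
 * placeholders for the caller's own objects):
 * <pre>{@code
 * ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
 *         .workers(4)                        // number of parallel training threads
 *         .prefetchBuffer(8)                 // background prefetch queue size
 *         .averagingFrequency(3)             // average parameters every 3 iterations
 *         .reportScoreAfterAveraging(true)
 *         .build();
 * wrapper.fit(trainData);
 * wrapper.shutdown();
 * }</pre>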
 *
 * @author raver119@gmail.com
 */
// TODO: We want this thing to be NUMA-aware in the foreseeable future
@Slf4j
@Data
public class ParallelWrapper implements AutoCloseable {
    public enum TrainingMode {
        /**
         * Parameters averaging will be applied every X iterations
         */
        AVERAGING,

        /**
         * Models within ParallelWrapper instance will share gradient updates
         */
        SHARED_GRADIENTS,

        /**
         * This option assumes use of GradientsAccumulator with any MessageHandler
         */
        CUSTOM,
    }

    protected Supplier<INDArray> modelParamsSupplier;
    protected Supplier<INDArray> updaterParamsSupplier;

    protected AtomicBoolean exceptionEncountered;
    protected Throwable exception;
    protected final String uuid = java.util.UUID.randomUUID().toString();
    protected Model model;
    protected int workers = 2;
    protected int prefetchSize = 2;
    protected int averagingFrequency = 1;
    protected Trainer[] zoo;
    protected TrainerContext trainerContext;
    protected AtomicLong iterationsCounter = new AtomicLong(0);
    protected boolean reportScore = false;
    protected boolean averageUpdaters = true;
    protected boolean legacyAveraging = false;
    protected boolean wasAveraged = false;
    protected AtomicBoolean stopFit = new AtomicBoolean(false);
    protected List<TrainingListener> listeners = new ArrayList<>();
    protected StatsStorageRouter storageRouter;
    protected boolean isMQ;
    protected WorkspaceMode workspaceMode;
    protected Object[] trainerContextArgs;
    protected boolean debug = false;

    protected ThreadPoolExecutor executorService;

    protected final AtomicInteger workerCounter = new AtomicInteger(0);
    @Getter
    @Setter
    protected GradientsAccumulator gradientsAccumulator;

    // log uncaught exceptions
    Thread.UncaughtExceptionHandler handler = new Thread.UncaughtExceptionHandler() {
        public void uncaughtException(Thread th, Throwable ex) {
            log.error("Uncaught exception in thread [" + th.getName() + "]", ex);
            if (exceptionEncountered != null) {
                exceptionEncountered.set(true);
                exception = ex;
            }
        }
    };

    protected ParallelWrapper(Model model, int workers, int prefetchSize) {
        this.model = model;
        this.workers = workers;
        this.prefetchSize = prefetchSize;

        if (this.model instanceof MultiLayerNetwork) {
            ((MultiLayerNetwork) this.model).getUpdater();
        } else if (this.model instanceof ComputationGraph) {
            ((ComputationGraph) this.model).getUpdater();
        }
    }

    protected void init() {
        workerCounter.set(0);
        this.executorService = (ThreadPoolExecutor) Executors.newFixedThreadPool(workers, new ThreadFactory() {
            @Override
            public Thread newThread(@NonNull final Runnable r) {
                final int cThread = workerCounter.getAndIncrement();

                Thread t = new Thread(new Runnable() {
                    @Override
                    public void run() {
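                        // pin this worker thread to a device in round-robin fashion, so that
                        // workers are spread evenly across all available devices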
                        Nd4j.getAffinityManager().unsafeSetDevice(cThread % Nd4j.getAffinityManager().getNumberOfDevices());
                        r.run();
                    }
                });

                t.setName("ParallelWrapper training thread " + cThread);
                t.setDaemon(true);
                t.setUncaughtExceptionHandler(handler);

                return t;
            }
        });
    }

    @Override
    public void close() throws Exception {
        if (zoo != null) {
            for (int i = 0; i < zoo.length; i++) {
                if (zoo[i] != null)
                    zoo[i].shutdown();
            }
            zoo = null;
        }

        if (executorService != null) {
            executorService.shutdown();
            executorService = null;
        }

        if (gradientsAccumulator != null)
            gradientsAccumulator.reset();
    }

    /**
     * This method causes all threads used for parallel training to stop
     */
    public synchronized void shutdown() {
        try {
            close();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Will stop a fit operation from continuing to iterate.
     */
    public void stopFit() {
        stopFit.set(true);
    }

    /**
     * This method takes MultiDataSetIterator, and starts training over it by scheduling MultiDataSets to different executors
     *
     * @param source
     */
    public synchronized void fit(@NonNull MultiDataSetIterator source) {
        stopFit.set(false);
        createZooIfNecessary(true);

        if (!source.hasNext() && source.resetSupported())
            source.reset();

        MultiDataSetIterator iterator = source;
        if (prefetchSize > 0 && source.asyncSupported()) {
            if (workers % Nd4j.getAffinityManager().getNumberOfDevices() != 0 && isMQ)
                log.warn("Number of workers [{}] isn't optimal for available devices [{}]", workers,
                                Nd4j.getAffinityManager().getNumberOfDevices());

            if (isMQ) {
                iterator = new AsyncMultiDataSetIterator(source, prefetchSize,
                                new LinkedBlockingQueue<>(prefetchSize * workers), true,
                                new InterleavedDataSetCallback(prefetchSize * 2));
            } else
                iterator = new AsyncMultiDataSetIterator(source, prefetchSize);
        }

        val locker = new AtomicInteger(0);
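
        // wrap the iterator so that up to `workers` MultiDataSets can be pulled at once, one per worker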
        val blockWrapper = new DummyBlockMultiDataSetIterator(iterator);

        var time1 = System.currentTimeMillis();
        while (blockWrapper.hasAnything() && !stopFit.get()) {
            if (modelParamsSupplier != null) {
                val params = modelParamsSupplier.get();
                if (params != null) {
                    if (zoo != null)
                        for (val z : zoo)
                            z.updateModelParams(params);
                }
            }

            if (updaterParamsSupplier != null) {
                val params = updaterParamsSupplier.get();
                if (params != null) {
                    if (zoo != null)
                        for (val z : zoo)
                            z.updateUpdaterParams(params);
                }
            }

            val dataSets = blockWrapper.next(workers);
            long time2 = System.currentTimeMillis();

            if (dataSets == null)
                throw new ND4JIllegalStateException("You can't have NULL as MultiDataSet");

            locker.set(dataSets.length);

            /*
             * if we're using a registerable accumulator (i.e. we're on spark or cuda with gradients sharing),
             * update it & notify it about the number of threads in this training round
             */
            if (gradientsAccumulator instanceof Registerable) {
                ((Registerable) gradientsAccumulator).registerConsumers(dataSets.length);
            }

            /*
             * now the datasets should be dispatched to the next free workers, until all workers are busy,
             * and then we should block till all are finished
             */
            for (int pos = 0; pos < dataSets.length; pos++) {
                zoo[pos].feedMultiDataSet(dataSets[pos], time2 - time1);
            }

            iterationsCounter.incrementAndGet();

            /*
             * if all workers are dispatched now, join till all are finished
             */
            for (int pos = 0; pos < dataSets.length; pos++) {
                zoo[pos].waitTillRunning();
            }

            //Nd4j.getMemoryManager().invokeGcOccasionally();

            // optional averaging
            if (zoo[0].averagingRequired() && iterationsCounter.get() % averagingFrequency == 0) {
                /*
                 * average model, and propagate it to all workers
                 */
                double score = getScore(locker);

                // averaging updaters state
                averageUpdatersState(locker, score);
            }

            locker.set(0);

            time1 = System.currentTimeMillis();
        }

        if (debug)
            log.info("Stopping everyone...");

        if (debug)
            log.info("Shutting down iterator...");

        if (prefetchSize > 0 && source.asyncSupported())
            ((AsyncMultiDataSetIterator) iterator).shutdown();

        /*
        // TODO: get rid of this code, 0 model is not replicated anyway
        // now we transfer models back from workers
        List<Model> models = new ArrayList<>();
        for (int i = 0; i < zoo.length; i++) {
            models.add(zoo[0].getModel());
        }

        // actual transfer code depends on trainer
        trainerContext.finalizeTraining(model, models.toArray(new Model[0]));
        */
        try {
            close();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        // sanity check, or the parameters may never be averaged
        if (!wasAveraged)
            log.warn("Parameters were never averaged on current fit(). Ratios of batch size, number of workers, and averaging frequency may be responsible.");
        // throw new IllegalStateException("Parameters were never averaged. Please check batch size ratios, number of workers, and your averaging frequency.");

        log.debug("Iterations passed: {}", iterationsCounter.get());
        // iterationsCounter.set(0);
    }

    private double getScore(AtomicInteger locker) {
        wasAveraged = true;
        double score = 0.0;

        List<INDArray> params = new ArrayList<>();
        for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
            params.add(zoo[cnt].getModel().params());
            score += zoo[cnt].getModel().score();
        }
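
        // averages the collected parameter arrays across the active workers and writes the
        // averaged values back into each worker's array in-place (no separate target array)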
        Nd4j.averageAndPropagate(null, params);

        score /= Math.min(workers, locker.get());

        // TODO: improve this
        if (reportScore)
            log.info("Averaged score: {}", score);

        return score;
    }

    private void averageUpdatersState(AtomicInteger locker, double score) {
        // averaging updaters state
        if (model instanceof MultiLayerNetwork) {
            if (averageUpdaters) {
                Updater updater = ((MultiLayerNetwork) model).getUpdater();
                int batchSize = 0;

                if (updater != null && updater.getStateViewArray() != null) {
                    List<INDArray> updaters = new ArrayList<>();
                    for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                        MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[cnt].getModel();
                        updaters.add(workerModel.getUpdater().getStateViewArray());
                        batchSize += workerModel.batchSize();
                    }

                    Nd4j.averageAndPropagate(null, updaters);
                }
            }

            ((MultiLayerNetwork) model).setScore(score);
        } else if (model instanceof ComputationGraph) {
            if (averageUpdaters) {
                ComputationGraphUpdater updater = ((ComputationGraph) model).getUpdater();
                int batchSize = 0;

                if (updater != null && updater.getStateViewArray() != null) {
                    List<INDArray> updaters = new ArrayList<>();
                    for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                        ComputationGraph workerModel = (ComputationGraph) zoo[cnt].getModel();
                        updaters.add(workerModel.getUpdater().getStateViewArray());
                        batchSize += workerModel.batchSize();
                    }

                    Nd4j.averageAndPropagate(null, updaters);
                }
            }

            ((ComputationGraph) model).setScore(score);
        }
    }

    /**
     * This method allows you to specify trainingListeners for this model.
     * Note that for listeners like StatsListener (that have state that will be sent somewhere), consider instead
     * using {@link #setListeners(StatsStorageRouter, Collection)}
     *
     * @param listeners Listeners to set
     */
    public void setListeners(@NonNull Collection<TrainingListener> listeners) {
        setListeners(null, listeners);
    }

    /**
     * This method allows you to specify trainingListeners for this model.
     * Note that for listeners like StatsListener (that have state that will be sent somewhere), consider instead
     * using {@link #setListeners(StatsStorageRouter, Collection)}
     *
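     * A minimal usage sketch (assuming {@code ScoreIterationListener} from core DL4J is
     * on the classpath; the print frequency of 10 is just a placeholder):
     * <pre>{@code
     * wrapper.setListeners(new ScoreIterationListener(10));
     * }</pre>
     *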
     * @param listeners Listeners to set
     */
    public void setListeners(@NonNull TrainingListener... listeners) {
        setListeners(Arrays.asList(listeners));
    }

    /**
     * Set the listeners, along with a StatsStorageRouter that the results will be shuffled to (in the case of any listeners
     * that implement the {@link RoutingIterationListener} interface)
     *
     * @param statsStorage Stats storage router to place the results into
     * @param listeners    Listeners to set
     */
    public void setListeners(StatsStorageRouter statsStorage, TrainingListener... listeners) {
        setListeners(statsStorage, Arrays.asList(listeners));
    }

    /**
     * Set the listeners, along with a StatsStorageRouter that the results will be shuffled to (in the case of any listeners
     * that implement the {@link RoutingIterationListener} interface)
     *
     * @param statsStorage Stats storage router to place the results into
     * @param listeners    Listeners to set
     */
    public void setListeners(StatsStorageRouter statsStorage, Collection<? extends TrainingListener> listeners) {
        //Check if we have any RoutingIterationListener instances that need a StatsStorage implementation...
        if (listeners != null) {
            for (TrainingListener l : listeners) {
                if (l instanceof RoutingIterationListener) {
                    RoutingIterationListener rl = (RoutingIterationListener) l;
                    if (statsStorage == null && rl.getStorageRouter() == null) {
                        log.warn("RoutingIterationListener provided without providing any StatsStorage instance. Listener may not function without one. Listener: {}",
                                        l);
                    }
                }
            }

            this.listeners.addAll(listeners);
        } else {
            this.listeners.clear();
        }

        this.storageRouter = statsStorage;
    }

    /**
     * This method will propagate gradients across all workers
     *
     * @param gradients
     */
    public void broadcastGradients(SharedGradient gradients) {
        // TODO: add implementation
        /*
            Basically all we want here is:
            1) Ensure length matches parameters length
            2) Ensure data is accessible from all devices somehow (i.e. it's in HOST-only mode)
         */
        /*
        if (zoo[0] instanceof CommunicativeTrainer) {
            for (int i = 0; i < zoo.length; i++) {
                ((CommunicativeTrainer) zoo[i]).enqueueGradient(gradients);
            }
        }
        */
    }

    /**
     * This method takes DataSetIterator, and starts training over it by scheduling DataSets to different executors
     *
     * @param source
     */
    public synchronized void fit(@NonNull DataSetIterator source) {
        log.info("Using workspaceMode {} for training", workspaceMode.name());
        stopFit.set(false);
        createZooIfNecessary(false);

        if (!source.hasNext() && source.resetSupported())
            source.reset();

        DataSetIterator iterator = source;

        if (prefetchSize > 0 && source.asyncSupported()) {
            log.info("Creating asynchronous prefetcher...");
            if (isMQ) {
                if (workers % Nd4j.getAffinityManager().getNumberOfDevices() != 0)
                    log.warn("Number of workers [{}] isn't optimal for available devices [{}]", workers,
                                    Nd4j.getAffinityManager().getNumberOfDevices());

                iterator = new AsyncDataSetIterator(source, prefetchSize,
                                new LinkedBlockingQueue<>(prefetchSize * workers), true,
                                new InterleavedDataSetCallback(prefetchSize * 2));

            } else
                iterator = new AsyncDataSetIterator(source, prefetchSize);
        }

        val nanos = new ArrayList<Long>();
        val locker = new AtomicInteger(0);
        var time1 = System.currentTimeMillis();
        log.info("Starting ParallelWrapper training round...");
        long intcnt = 0;
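
        // wrap the iterator so that up to `workers` DataSets can be pulled at once, one per worker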
        val blockWrapper = new DummyBlockDataSetIterator(iterator);

        while (blockWrapper.hasAnything() && !stopFit.get()) {
            if (modelParamsSupplier != null) {
                val params = modelParamsSupplier.get();
                if (params != null) {
                    if (zoo != null) {
                        log.info("Updating model parameters...");
                        for (val z : zoo) {
                            z.updateModelParams(params);
                        }
                    }
                }
            }

            if (updaterParamsSupplier != null) {
                val params = updaterParamsSupplier.get();
                if (params != null) {
                    if (zoo != null) {
                        log.info("Updating updater parameters...");
                        for (val z : zoo) {
                            z.updateUpdaterParams(params);
                        }
                    }
                }
            }

            intcnt++;
            val dataSets = blockWrapper.next(workers);
            var time2 = System.currentTimeMillis();
            var lastEtlTime = time2 - time1;

            if (dataSets == null)
                throw new ND4JIllegalStateException("You can't have NULL as DataSet");

            if (zoo == null)
                throw new IllegalStateException(
                                "ParallelWrapper.shutdown() has been called too early; training can't continue from this point forward.");

            locker.set(dataSets.length);

            /*
             * if we're using a registerable accumulator (i.e. we're on spark or cuda with gradients sharing),
             * update it & notify it about the number of threads in this training round
             */
            if (gradientsAccumulator instanceof Registerable) {
                ((Registerable) gradientsAccumulator).registerConsumers(dataSets.length);
            }

            // feeding datasets
            for (int pos = 0; pos < dataSets.length; pos++) {
                if (debug)
                    log.info("Feeding dataset {} to worker {}", intcnt, pos);

                zoo[pos].feedDataSet(dataSets[pos], lastEtlTime);
            }

            iterationsCounter.incrementAndGet();

            // waiting till all threads are done
            for (int pos = 0; pos < dataSets.length; pos++) {
                try {
                    zoo[pos].waitTillRunning();
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }

            // optional averaging
            if (iterationsCounter.get() % averagingFrequency == 0 && zoo[0].averagingRequired()) {
                long timeA1 = System.currentTimeMillis();

                // model averaging happens within
                double score = getScore(locker);

                // updater averaging happens within (if any)
                averageUpdatersState(locker, score);

                long timeA2 = System.currentTimeMillis();
                if (reportScore)
                    log.info("Averaging time: {} ms", timeA2 - timeA1);
            }

            time1 = System.currentTimeMillis();
            locker.set(0);
        }

        if (debug)
            log.info("Stopping everyone...");

        // ensure all threads stopped processing
        for (int cnt = 0; cnt < workers; cnt++) {
            try {
                zoo[cnt].waitTillRunning();
            } catch (Exception e) {
                throw new RuntimeException(e);
            }
        }

        if (debug)
            log.info("Shutting down iterator...");

        if (prefetchSize > 0 && source.asyncSupported())
            ((AsyncDataSetIterator) iterator).shutdown();

        try {
            close();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }

        if (debug)
            log.info("Iterations passed: {}", iterationsCounter.get());
    }

    private void createZooIfNecessary(boolean useMDS) {
        if (zoo == null) {
            trainerContext.init(model, trainerContextArgs);

            zoo = new Trainer[workers];
            int numDevices = Nd4j.getAffinityManager().getNumberOfDevices();
            for (int cnt = 0; cnt < workers; cnt++) {
                // we pass useMDS here, to tell the Trainer whether to use the MultiDataSet queue for training
                zoo[cnt] = trainerContext.create(this.uuid, cnt, model, Nd4j.getAffinityManager().getDeviceForCurrentThread(),
                                useMDS, this, workspaceMode, averagingFrequency);

                /*
                zoo[cnt].setUncaughtExceptionHandler(handler);

                if (zoo[cnt] instanceof Thread) {
                    Nd4j.getAffinityManager().attachThreadToDevice((Thread) zoo[cnt], cnt % numDevices);
                }
                zoo[cnt].start();
                */

                if (executorService == null)
                    init();

                executorService.execute(zoo[cnt]);
            }
        }
    }

    public static class Builder<T extends Model> {
        protected TrainingMode trainingMode = TrainingMode.AVERAGING;
        protected T model;
        protected int workers = Nd4j.getAffinityManager().getNumberOfDevices();
        protected int prefetchSize = 16;
        protected int averagingFrequency = 1;
        protected boolean reportScore = false;
        protected boolean averageUpdaters = true;
        protected boolean legacyAveraging = true;
        protected boolean isMQ = Nd4j.getAffinityManager().getNumberOfDevices() > 1;
        protected TrainerContext trainerContext = new DefaultTrainerContext();
        protected Object[] trainerContextArgs;
        protected WorkspaceMode workspaceMode = WorkspaceMode.ENABLED;
        protected Supplier<INDArray> modelParamsSupplier;
        protected Supplier<INDArray> updaterParamsSupplier;
        protected ThresholdAlgorithm thresholdAlgorithm;
        protected ResidualPostProcessor residualPostProcessor;

        protected GradientsAccumulator accumulator;

        /**
         * Trainer context args are for calling a
         * {@link TrainerContext} init method
         * when {@link ParallelWrapper} starts training
         *
         * @param trainerContextArgs the args to use (may be null)
         * @return builder pattern
         */
        public Builder trainerContextArgs(Object... trainerContextArgs) {
            this.trainerContextArgs = trainerContextArgs;
            return this;
        }

        /**
         * Specify a {@link TrainerContext}
         * for the given {@link ParallelWrapper}
         * instance.
         * Defaults to {@link DefaultTrainerContext}
         * otherwise
         *
         * @param trainerContext the trainer factory to use
         * @return builder pattern
         */
        public Builder trainerFactory(@NonNull TrainerContext trainerContext) {
            this.trainerContext = trainerContext;
            return this;
        }

        /**
         * This method allows you to override the model's WorkspaceMode configuration option
         *
         * @param mode
         * @return builder pattern
         */
        public Builder workspaceMode(@NonNull WorkspaceMode mode) {
            this.workspaceMode = mode;
            return this;
        }

        /**
         * This method attaches a supplier that may provide updated model params
         *
         * PLEASE NOTE: This method is mostly used in Spark environment as part of fault tolerance logic
         *
         * @param supplier
         * @return builder pattern
         */
        public Builder modelParamsSupplier(Supplier<INDArray> supplier) {
            this.modelParamsSupplier = supplier;
            return this;
        }

        /**
         * This method attaches a supplier that may provide updated updater params
         *
         * PLEASE NOTE: This method is mostly used in Spark environment as part of fault tolerance logic
         *
         * @param supplier
         * @return builder pattern
         */
        public Builder updaterParamsSupplier(Supplier<INDArray> supplier) {
            this.updaterParamsSupplier = supplier;
            return this;
        }

        /**
         * Build ParallelWrapper for the given model (MultiLayerNetwork or ComputationGraph)
         *
         * @param model
         */
        public Builder(@NonNull T model) {
            this.model = model;
        }

        /**
         * This method allows you to configure the number of workers that'll be used for parallel training
         *
         * @param num
         * @return builder pattern
         */
        public Builder workers(int num) {
            if (num < 2)
                throw new RuntimeException("Number of workers can't be lower than 2!");

            this.workers = num;
            return this;
        }

        /**
         * Model averaging frequency.
         *
         * @param freq number of iterations between averaging
         * @return builder pattern
         */
        public Builder averagingFrequency(int freq) {
            if (freq < 0)
                freq = 0;

            this.averagingFrequency = freq;
            return this;
        }

        /**
         * This method enables/disables updaters averaging.
         *
         * Default value: TRUE
         *
         * PLEASE NOTE: This method is mostly suitable for debugging purposes. Don't change the default value unless you're sure why you need it.
         * PLEASE NOTE: This method is suitable for parameters averaging training only. For the gradients sharing mechanism it'll be ignored.
         *
         * @param reallyAverage
         * @return builder pattern
         */
        public Builder averageUpdaters(boolean reallyAverage) {
            this.averageUpdaters = reallyAverage;
            return this;
        }

        /**
         * Size of the prefetch buffer that will be used for background data prefetching.
         * Usually it's better to keep this value equal to the number of workers.
         *
         * Default value: 2
         *
         * @param size 0 to disable prefetching, or any positive number
         * @return builder pattern
         */
        public Builder prefetchBuffer(int size) {
            if (size < 0)
                size = 0;

            this.prefetchSize = size;

            return this;
        }

        /**
         * This method allows you to specify the training mode for this instance of PW.<br>
         * 1) AVERAGING - stands for parameters averaging. Every X iterations, weights and updater state will be averaged across all models<br>
         * 2) SHARED_GRADIENTS - stands for gradients sharing - more details available here: <a href="https://deeplearning4j.org/docs/latest/deeplearning4j-scaleout-intro">https://deeplearning4j.org/docs/latest/deeplearning4j-scaleout-intro</a><br>
         * 3) CUSTOM - this mode allows you to specify a custom gradients accumulator, thus giving you better control of configuration params for training.<br>
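         * <br>
         * A minimal sketch of configuring gradients sharing (assuming {@code AdaptiveThresholdAlgorithm}
         * or another {@code ThresholdAlgorithm} implementation is available; the 1e-3 value is a
         * placeholder, not a recommendation):
         * <pre>{@code
         * ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
         *         .trainingMode(TrainingMode.SHARED_GRADIENTS)
         *         .thresholdAlgorithm(new AdaptiveThresholdAlgorithm(1e-3))
         *         .build();
         * }</pre>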
         *
         * @param mode
         * @return builder pattern
         */
        public Builder trainingMode(@NonNull TrainingMode mode) {
            this.trainingMode = mode;
            return this;
        }

        /**
         * This method allows you to specify a GradientsAccumulator instance to be used in this ParallelWrapper instance
         *
         * PLEASE NOTE: This method is applicable only to gradients sharing mechanics. If parameters averaging is used, the accumulator will be ignored
         *
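         * A minimal sketch for CUSTOM training mode, reusing the same constructor this class
         * uses internally ({@code thresholdAlgorithm} and {@code residualPostProcessor} are
         * assumed to be supplied by the caller):
         * <pre>{@code
         * GradientsAccumulator acc = new EncodedGradientsAccumulator(4, thresholdAlgorithm, residualPostProcessor, false);
         * ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
         *         .trainingMode(TrainingMode.CUSTOM)
         *         .gradientsAccumulator(acc)
         *         .build();
         * }</pre>
         *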
         * @param accumulator
         * @return builder pattern
         */
        public Builder gradientsAccumulator(@NonNull GradientsAccumulator accumulator) {
            this.accumulator = accumulator;
            return this;
        }

        /**
         * This method enables/disables averaged model score reporting
         *
         * @param reallyReport
         * @return builder pattern
         */
        public Builder reportScoreAfterAveraging(boolean reallyReport) {
            this.reportScore = reallyReport;
            return this;
        }

        /**
         * Set the threshold algorithm. Not used for single machine training (only for PW used in a distributed setting),
         * and should not be set by users in most cases.
         *
         * @param thresholdAlgorithm Threshold algorithm to use
         */
        public Builder thresholdAlgorithm(ThresholdAlgorithm thresholdAlgorithm) {
            this.thresholdAlgorithm = thresholdAlgorithm;
            return this;
        }

        /**
         * Set the residual post processor algorithm. Not used for single machine training (only for PW used in a
         * distributed setting), and should not be set by users in most cases.
         *
         * @param residualPostProcessor Residual post processor to use
         */
        public Builder residualPostProcessor(ResidualPostProcessor residualPostProcessor) {
            this.residualPostProcessor = residualPostProcessor;
            return this;
        }

        /**
         * This method returns a ParallelWrapper instance
         *
         * @return ParallelWrapper instance
         */
        public ParallelWrapper build() {
            ParallelWrapper wrapper = new ParallelWrapper(model, workers, prefetchSize);
            wrapper.averagingFrequency = this.averagingFrequency;
            wrapper.reportScore = this.reportScore;
            wrapper.averageUpdaters = this.averageUpdaters;
            wrapper.legacyAveraging = this.legacyAveraging;
            wrapper.isMQ = this.isMQ;
            wrapper.workspaceMode = this.workspaceMode;
            wrapper.modelParamsSupplier = this.modelParamsSupplier;
            wrapper.updaterParamsSupplier = this.updaterParamsSupplier;

            switch (trainingMode) {
                case AVERAGING: {
                    this.trainerContext = new DefaultTrainerContext();
                    this.accumulator = null;
                    log.info("Creating new AveragingTraining instance");
                    break;
                }
                case SHARED_GRADIENTS: {
                    Preconditions.checkState(thresholdAlgorithm != null, "Cannot use SHARED_GRADIENTS training mode without setting a threshold algorithm");
                    this.trainerContext = new SymmetricTrainerContext();
                    if (this.accumulator == null) {
                        log.info("Creating new EncodedGradientsAccumulator instance");
                        this.accumulator = new EncodedGradientsAccumulator(workers, thresholdAlgorithm, residualPostProcessor, false);
                    }
                    break;
                }
                case CUSTOM: {
                    this.trainerContext = new SymmetricTrainerContext();
                    if (this.accumulator == null)
                        throw new DL4JInvalidConfigException(
                                        "Please specify GradientsAccumulator for encoded gradients mode");
                    break;
                }
                default:
                    throw new UnsupportedOperationException("Unknown trainingMode: [" + trainingMode + "]");
            }

            wrapper.trainerContext = this.trainerContext;
            wrapper.gradientsAccumulator = this.accumulator;

            wrapper.init();

            List<TrainingListener> modelListeners = null;
            if (model instanceof MultiLayerNetwork) {
                modelListeners = new ArrayList<>(((MultiLayerNetwork) model).getListeners());
                model.setListeners(Collections.emptyList());
            } else if (model instanceof ComputationGraph) {
                modelListeners = new ArrayList<>(((ComputationGraph) model).getListeners());
                model.setListeners(Collections.emptyList());
            }

            if (modelListeners != null && !modelListeners.isEmpty()) {
                wrapper.setListeners(modelListeners);
            }

            return wrapper;
        }
    }

    private static TrainingListener cloneListener(TrainingListener original) {
        if (original instanceof RoutingIterationListener) {
            return ((RoutingIterationListener) original).clone();
        }
        return original;
    }

    private void configureListeners(String workerUUID, Collection<TrainingListener> oldListeners,
                    Collection<TrainingListener> replicatedListeners) {
        for (TrainingListener listener : oldListeners) {
            TrainingListener l = cloneListener(listener);

            if (l instanceof RoutingIterationListener) {
                RoutingIterationListener rl = (RoutingIterationListener) l;
                //We're assuming session ID is set by the original RoutingIterationListener constructor, which means
                // it will be synced across all cloned instances
                rl.setSessionID(((RoutingIterationListener) listener).getSessionID());
                rl.setWorkerID(workerUUID);

                StatsStorageRouter currentRouter = ((RoutingIterationListener) listener).getStorageRouter();
                if (currentRouter != null) {
                    //User has set router on the listener/model, instead of via the
                    // setListeners(StatsStorageRouter, ...) method
                    rl.setStorageRouter(currentRouter);
                } else {
                    rl.setStorageRouter(ParallelWrapper.this.storageRouter);
                }
            }
            replicatedListeners.add(l);
        }
    }
}