337 lines
11 KiB
Java
337 lines
11 KiB
Java
|
/*******************************************************************************
|
||
|
* Copyright (c) 2015-2018 Skymind, Inc.
|
||
|
*
|
||
|
* This program and the accompanying materials are made available under the
|
||
|
* terms of the Apache License, Version 2.0 which is available at
|
||
|
* https://www.apache.org/licenses/LICENSE-2.0.
|
||
|
*
|
||
|
* Unless required by applicable law or agreed to in writing, software
|
||
|
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||
|
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||
|
* License for the specific language governing permissions and limitations
|
||
|
* under the License.
|
||
|
*
|
||
|
* SPDX-License-Identifier: Apache-2.0
|
||
|
******************************************************************************/
|
||
|
|
||
|
package org.deeplearning4j.optimize.listeners;
|
||
|
|
||
|
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.io.Serializable;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.*;
|
||
|
|
||
|
/**
|
||
|
* WARNING: THIS LISTENER SHOULD ONLY BE USED FOR MANUAL TESTING PURPOSES<br>
|
||
|
* It intentionally causes various types of failures according to some criteria, in order to test the response
|
||
|
* to it.<br>
|
||
|
* This is useful for example in:
|
||
|
* (a) Testing Spark fault tolerance<br>
|
||
|
* (b) Testing OOM exception crash dump information<br>
|
||
|
* Generally it should not be used in unit tests either, depending on how it is configured.<br>
|
||
|
* <br>
|
||
|
* Two aspects need to be configured to use this listener:
|
||
|
* 1. If/when the "failure" should be triggered - via FailureTrigger classes<br>
|
||
|
* 2. The type of failure when triggered - via FailureMode enum<br>
|
||
|
* <br>
|
||
|
* To specify if/when a failure should be triggered, use a {@link FailureTrigger} instance. Some built-in ones
|
||
|
* are provided, random probability, time since initialized, username, and iteration/epoch count.
|
||
|
* <br>
|
||
|
* Types of failures available:<br>
|
||
|
* - OOM (allocate large arrays in loop until OOM).<br>
|
||
|
* - System.exit(1)<br>
|
||
|
* - IllegalStateException<br>
|
||
|
* - Infinite sleep<br>
|
||
|
*
|
||
|
* @author Alex Black
|
||
|
*/
|
||
|
@Slf4j
|
||
|
public class FailureTestingListener implements TrainingListener, Serializable {
|
||
|
|
||
|
public enum FailureMode {OOM, SYSTEM_EXIT_1, ILLEGAL_STATE, INFINITE_SLEEP}
|
||
|
public enum CallType {ANY, EPOCH_START, EPOCH_END, FORWARD_PASS, GRADIENT_CALC, BACKWARD_PASS, ITER_DONE}
|
||
|
|
||
|
private final FailureTrigger trigger;
|
||
|
private final FailureMode failureMode;
|
||
|
|
||
|
public FailureTestingListener(@NonNull FailureMode mode, @NonNull FailureTrigger trigger){
|
||
|
this.trigger = trigger;
|
||
|
this.failureMode = mode;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public void iterationDone(Model model, int iteration, int epoch) {
|
||
|
call(CallType.ITER_DONE, model);
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public void onEpochStart(Model model) {
|
||
|
call(CallType.EPOCH_START, model);
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public void onEpochEnd(Model model) {
|
||
|
call(CallType.EPOCH_END, model);
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public void onForwardPass(Model model, List<INDArray> activations) {
|
||
|
call(CallType.FORWARD_PASS, model);
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public void onForwardPass(Model model, Map<String, INDArray> activations) {
|
||
|
call(CallType.FORWARD_PASS, model);
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public void onGradientCalculation(Model model) {
|
||
|
call(CallType.GRADIENT_CALC, model);
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public void onBackwardPass(Model model) {
|
||
|
call(CallType.BACKWARD_PASS, model);
|
||
|
}
|
||
|
|
||
|
protected void call(CallType callType, Model model){
|
||
|
if(!trigger.initialized()){
|
||
|
trigger.initialize();
|
||
|
}
|
||
|
|
||
|
int iter;
|
||
|
int epoch;
|
||
|
if(model instanceof MultiLayerNetwork){
|
||
|
iter = ((MultiLayerNetwork) model).getIterationCount();
|
||
|
epoch = ((MultiLayerNetwork) model).getEpochCount();
|
||
|
} else {
|
||
|
iter = ((ComputationGraph) model).getIterationCount();
|
||
|
epoch = ((ComputationGraph) model).getEpochCount();
|
||
|
}
|
||
|
boolean triggered = trigger.triggerFailure(callType, iter, epoch, model);
|
||
|
|
||
|
if(triggered){
|
||
|
log.error("*** FailureTestingListener was triggered on iteration {}, epoch {} - Failure mode is set to {} ***",
|
||
|
iter, epoch, failureMode);
|
||
|
switch (failureMode){
|
||
|
case OOM:
|
||
|
List<INDArray> list = new ArrayList<>();
|
||
|
while(true){
|
||
|
INDArray arr = Nd4j.createUninitialized(1_000_000_000);
|
||
|
list.add(arr);
|
||
|
}
|
||
|
//break;
|
||
|
case SYSTEM_EXIT_1:
|
||
|
log.error("Exiting due to FailureTestingListener triggering - calling System.exit(1)");
|
||
|
System.exit(1);
|
||
|
break;
|
||
|
case ILLEGAL_STATE:
|
||
|
log.error("Throwing new IllegalStateException due to FailureTestingListener triggering");
|
||
|
throw new IllegalStateException("FailureTestListener was triggered with failure mode " + failureMode
|
||
|
+ " - iteration " + iter + ", epoch " + epoch);
|
||
|
case INFINITE_SLEEP:
|
||
|
while(true){
|
||
|
try {
|
||
|
Thread.sleep(10000);
|
||
|
} catch (InterruptedException e){
|
||
|
//Ignore
|
||
|
}
|
||
|
}
|
||
|
default:
|
||
|
throw new RuntimeException("Unknown enum value: " + failureMode);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
|
||
|
@Data
|
||
|
public static abstract class FailureTrigger implements Serializable {
|
||
|
|
||
|
private boolean initialized = false;
|
||
|
|
||
|
/**
|
||
|
* If true: trigger the failure. If false: don't trigger failure
|
||
|
* @param callType Type of call
|
||
|
* @param iteration Iteration number
|
||
|
* @param epoch Epoch number
|
||
|
* @param model Model
|
||
|
* @return
|
||
|
*/
|
||
|
public abstract boolean triggerFailure(CallType callType, int iteration, int epoch, Model model);
|
||
|
|
||
|
public boolean initialized(){
|
||
|
return initialized;
|
||
|
}
|
||
|
|
||
|
public void initialize(){
|
||
|
this.initialized = true;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
@AllArgsConstructor
|
||
|
public static class And extends FailureTrigger{
|
||
|
|
||
|
protected List<FailureTrigger> triggers;
|
||
|
|
||
|
public And(FailureTrigger... triggers){
|
||
|
this.triggers = Arrays.asList(triggers);
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
|
||
|
boolean b = true;
|
||
|
for(FailureTrigger ft : triggers)
|
||
|
b &= ft.triggerFailure(callType, iteration, epoch, model);
|
||
|
return b;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public void initialize(){
|
||
|
super.initialize();
|
||
|
for(FailureTrigger ft : triggers)
|
||
|
ft.initialize();
|
||
|
}
|
||
|
}
|
||
|
|
||
|
public static class Or extends And {
|
||
|
public Or(FailureTrigger... triggers) {
|
||
|
super(triggers);
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
|
||
|
boolean b = false;
|
||
|
for(FailureTrigger ft : triggers)
|
||
|
b |= ft.triggerFailure(callType, iteration, epoch, model);
|
||
|
return b;
|
||
|
}
|
||
|
}
|
||
|
|
||
|
@Data
|
||
|
public static class RandomProb extends FailureTrigger {
|
||
|
|
||
|
private final CallType callType;
|
||
|
private final double probability;
|
||
|
private Random rng;
|
||
|
|
||
|
public RandomProb(CallType callType, double probability){
|
||
|
this.callType = callType;
|
||
|
this.probability = probability;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
|
||
|
return (this.callType == CallType.ANY || callType == this.callType) && rng.nextDouble() < probability;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public void initialize(){
|
||
|
super.initialize();
|
||
|
this.rng = new Random();
|
||
|
}
|
||
|
}
|
||
|
|
||
|
|
||
|
@Data
|
||
|
public static class TimeSinceInitializedTrigger extends FailureTrigger {
|
||
|
|
||
|
private final long msSinceInit;
|
||
|
private long initTime;
|
||
|
|
||
|
public TimeSinceInitializedTrigger(long msSinceInit){
|
||
|
this.msSinceInit = msSinceInit;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
|
||
|
return (System.currentTimeMillis() - initTime) > msSinceInit;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public void initialize(){
|
||
|
super.initialize();
|
||
|
this.initTime = System.currentTimeMillis();
|
||
|
}
|
||
|
}
|
||
|
|
||
|
@Data
|
||
|
public static class UserNameTrigger extends FailureTrigger {
|
||
|
private final String userName;
|
||
|
private boolean shouldFail = false;
|
||
|
|
||
|
public UserNameTrigger(@NonNull String userName) {
|
||
|
this.userName = userName;
|
||
|
}
|
||
|
|
||
|
|
||
|
@Override
|
||
|
public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
|
||
|
return shouldFail;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public void initialize(){
|
||
|
super.initialize();
|
||
|
shouldFail = this.userName.equalsIgnoreCase(System.getProperty("user.name"));
|
||
|
}
|
||
|
}
|
||
|
//System.out.println("Hostname: " + InetAddress.getLocalHost().getHostName());
|
||
|
|
||
|
@Data
|
||
|
public static class HostNameTrigger extends FailureTrigger{
|
||
|
private final String hostName;
|
||
|
private boolean shouldFail = false;
|
||
|
|
||
|
public HostNameTrigger(@NonNull String hostName) {
|
||
|
this.hostName = hostName;
|
||
|
}
|
||
|
|
||
|
|
||
|
@Override
|
||
|
public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
|
||
|
return shouldFail;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public void initialize(){
|
||
|
super.initialize();
|
||
|
try {
|
||
|
String hostname = InetAddress.getLocalHost().getHostName();
|
||
|
log.info("FailureTestingListere hostname: {}", hostname);
|
||
|
shouldFail = this.hostName.equalsIgnoreCase(hostname);
|
||
|
} catch (Exception e){
|
||
|
throw new RuntimeException(e);
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
@Data
|
||
|
public static class IterationEpochTrigger extends FailureTrigger {
|
||
|
|
||
|
private final boolean isEpoch;
|
||
|
private final int count;
|
||
|
|
||
|
public IterationEpochTrigger(boolean isEpoch, int count){
|
||
|
this.isEpoch = isEpoch;
|
||
|
this.count = count;
|
||
|
}
|
||
|
|
||
|
@Override
|
||
|
public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
|
||
|
return (isEpoch && epoch == count) || (!isEpoch && iteration == count);
|
||
|
}
|
||
|
}
|
||
|
|
||
|
|
||
|
}
|