337 lines
11 KiB
Java
Raw Normal View History

2019-06-06 15:21:15 +03:00
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.deeplearning4j.optimize.listeners;
import lombok.AllArgsConstructor;
import lombok.Data;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.graph.ComputationGraph;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import java.io.Serializable;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.*;
/**
 * WARNING: THIS LISTENER SHOULD ONLY BE USED FOR MANUAL TESTING PURPOSES<br>
 * It intentionally causes various types of failures according to some criteria, in order to test the response
 * to them.<br>
 * This is useful, for example, for:<br>
 * (a) Testing Spark fault tolerance<br>
 * (b) Testing OOM exception crash dump information<br>
 * Generally it should not be used in unit tests either, depending on how it is configured.<br>
 * <br>
 * Two aspects need to be configured to use this listener:<br>
 * 1. If/when the "failure" should be triggered - via {@link FailureTrigger} classes<br>
 * 2. The type of failure when triggered - via the {@link FailureMode} enum<br>
 * <br>
 * To specify if/when a failure should be triggered, use a {@link FailureTrigger} instance. Several built-in ones
 * are provided: random probability ({@link RandomProb}), time since initialization
 * ({@link TimeSinceInitializedTrigger}), username ({@link UserNameTrigger}), hostname ({@link HostNameTrigger})
 * and iteration/epoch count ({@link IterationEpochTrigger}). Triggers can be combined with {@link And} and
 * {@link Or}.<br>
 * <br>
 * Types of failures available:<br>
 * - OOM (allocate large arrays in loop until OOM).<br>
 * - System.exit(1)<br>
 * - IllegalStateException<br>
 * - Infinite sleep<br>
 * <br>
 * The trigger is initialized lazily (via {@link FailureTrigger#initialize()}) on the first listener callback,
 * not when the listener is constructed.
 *
 * @author Alex Black
 */
@Slf4j
public class FailureTestingListener implements TrainingListener, Serializable {

    /** The type of failure to induce once the trigger fires. */
    public enum FailureMode {OOM, SYSTEM_EXIT_1, ILLEGAL_STATE, INFINITE_SLEEP}

    /**
     * The listener callback a trigger is being queried from. {@link #ANY} is never passed by the listener itself;
     * it is a wildcard that triggers such as {@link RandomProb} use to match every callback.
     */
    public enum CallType {ANY, EPOCH_START, EPOCH_END, FORWARD_PASS, GRADIENT_CALC, BACKWARD_PASS, ITER_DONE}

    private final FailureTrigger trigger;
    private final FailureMode failureMode;

    /**
     * @param mode    Type of failure to induce when the trigger fires
     * @param trigger Decides if/when the failure should be induced
     */
    public FailureTestingListener(@NonNull FailureMode mode, @NonNull FailureTrigger trigger){
        this.trigger = trigger;
        this.failureMode = mode;
    }

    @Override
    public void iterationDone(Model model, int iteration, int epoch) {
        call(CallType.ITER_DONE, model);
    }

    @Override
    public void onEpochStart(Model model) {
        call(CallType.EPOCH_START, model);
    }

    @Override
    public void onEpochEnd(Model model) {
        call(CallType.EPOCH_END, model);
    }

    @Override
    public void onForwardPass(Model model, List<INDArray> activations) {
        call(CallType.FORWARD_PASS, model);
    }

    @Override
    public void onForwardPass(Model model, Map<String, INDArray> activations) {
        call(CallType.FORWARD_PASS, model);
    }

    @Override
    public void onGradientCalculation(Model model) {
        call(CallType.GRADIENT_CALC, model);
    }

    @Override
    public void onBackwardPass(Model model) {
        call(CallType.BACKWARD_PASS, model);
    }

    /**
     * Initializes the trigger on first use, queries it with the model's current iteration/epoch counts and, if
     * it fires, induces the configured {@link FailureMode}.
     *
     * @param callType The listener callback this call originates from
     * @param model    The model being trained
     * @throws IllegalStateException if triggered with {@link FailureMode#ILLEGAL_STATE}
     */
    protected void call(CallType callType, Model model){
        if(!trigger.initialized()){
            trigger.initialize();
        }

        int iter;
        int epoch;
        if(model instanceof MultiLayerNetwork){
            iter = ((MultiLayerNetwork) model).getIterationCount();
            epoch = ((MultiLayerNetwork) model).getEpochCount();
        } else if(model instanceof ComputationGraph){
            iter = ((ComputationGraph) model).getIterationCount();
            epoch = ((ComputationGraph) model).getEpochCount();
        } else {
            //Counts are only available from MultiLayerNetwork/ComputationGraph. For any other model type report -1
            // instead of failing with a ClassCastException, so that triggers which don't depend on the counts
            // (time, random, username, hostname) still work
            iter = -1;
            epoch = -1;
        }

        boolean triggered = trigger.triggerFailure(callType, iter, epoch, model);
        if(triggered){
            log.error("*** FailureTestingListener was triggered on iteration {}, epoch {} - Failure mode is set to {} ***",
                    iter, epoch, failureMode);
            switch (failureMode){
                case OOM: {
                    //Keep references to every allocation so nothing can be reclaimed; loops until allocation fails
                    List<INDArray> list = new ArrayList<>();
                    while(true){
                        INDArray arr = Nd4j.createUninitialized(1_000_000_000);
                        list.add(arr);
                    }
                }
                case SYSTEM_EXIT_1:
                    log.error("Exiting due to FailureTestingListener triggering - calling System.exit(1)");
                    System.exit(1);
                    break;
                case ILLEGAL_STATE:
                    log.error("Throwing new IllegalStateException due to FailureTestingListener triggering");
                    throw new IllegalStateException("FailureTestingListener was triggered with failure mode " + failureMode
                            + " - iteration " + iter + ", epoch " + epoch);
                case INFINITE_SLEEP:
                    while(true){
                        try {
                            Thread.sleep(10000);
                        } catch (InterruptedException ignored){
                            //Deliberately ignored: this failure mode must hang forever. Restoring the interrupt
                            // flag here would make every subsequent sleep() throw immediately (a busy loop)
                        }
                    }
                default:
                    throw new RuntimeException("Unknown enum value: " + failureMode);
            }
        }
    }

    /**
     * Decides if/when a failure should be induced. Subclasses that need one-off setup should override
     * {@link #initialize()} and call {@code super.initialize()}.
     */
    @Data
    public static abstract class FailureTrigger implements Serializable {

        private boolean initialized = false;

        /**
         * If true: trigger the failure. If false: don't trigger failure
         * @param callType Type of call
         * @param iteration Iteration number (-1 if the listener cannot obtain it for this model type)
         * @param epoch Epoch number (-1 if the listener cannot obtain it for this model type)
         * @param model Model
         * @return true if the failure should be triggered now, false otherwise
         */
        public abstract boolean triggerFailure(CallType callType, int iteration, int epoch, Model model);

        /** @return true once {@link #initialize()} has been called */
        public boolean initialized(){
            return initialized;
        }

        /** Marks this trigger as initialized. The listener calls this before the first {@link #triggerFailure}. */
        public void initialize(){
            this.initialized = true;
        }
    }

    /**
     * Fires only when ALL child triggers fire. Every child is evaluated on each call (no short-circuiting).
     * With no children, it always fires.
     */
    @AllArgsConstructor
    public static class And extends FailureTrigger{

        /** Child triggers whose results are combined. */
        protected List<FailureTrigger> triggers;

        public And(FailureTrigger... triggers){
            this.triggers = Arrays.asList(triggers);
        }

        @Override
        public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
            boolean b = true;
            for(FailureTrigger ft : triggers) {
                b &= ft.triggerFailure(callType, iteration, epoch, model);
            }
            return b;
        }

        /** Initializes this trigger and every child trigger. */
        @Override
        public void initialize(){
            super.initialize();
            for(FailureTrigger ft : triggers) {
                ft.initialize();
            }
        }
    }

    /**
     * Fires when ANY child trigger fires. Every child is evaluated on each call (no short-circuiting).
     * With no children, it never fires.
     */
    public static class Or extends And {

        public Or(FailureTrigger... triggers) {
            super(triggers);
        }

        @Override
        public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
            boolean b = false;
            for(FailureTrigger ft : triggers) {
                b |= ft.triggerFailure(callType, iteration, epoch, model);
            }
            return b;
        }
    }

    /**
     * Fires with the given probability on each call whose type equals {@code callType} (or on every call, if
     * {@code callType} is {@link CallType#ANY}). The random number generator is created in {@link #initialize()}.
     */
    @Data
    public static class RandomProb extends FailureTrigger {

        private final CallType callType;
        private final double probability;
        private Random rng;

        public RandomProb(CallType callType, double probability){
            this.callType = callType;
            this.probability = probability;
        }

        @Override
        public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
            return (this.callType == CallType.ANY || callType == this.callType) && rng.nextDouble() < probability;
        }

        @Override
        public void initialize(){
            super.initialize();
            this.rng = new Random();
        }
    }

    /**
     * Fires on every call once more than {@code msSinceInit} milliseconds have elapsed since
     * {@link #initialize()} was called.
     */
    @Data
    public static class TimeSinceInitializedTrigger extends FailureTrigger {

        private final long msSinceInit;
        private long initTime;

        public TimeSinceInitializedTrigger(long msSinceInit){
            this.msSinceInit = msSinceInit;
        }

        @Override
        public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
            return (System.currentTimeMillis() - initTime) > msSinceInit;
        }

        @Override
        public void initialize(){
            super.initialize();
            this.initTime = System.currentTimeMillis();
        }
    }

    /**
     * Fires on every call if the {@code user.name} system property, read once in {@link #initialize()}, equals
     * the configured name (case-insensitive).
     */
    @Data
    public static class UserNameTrigger extends FailureTrigger {

        private final String userName;
        private boolean shouldFail = false;

        public UserNameTrigger(@NonNull String userName) {
            this.userName = userName;
        }

        @Override
        public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
            return shouldFail;
        }

        @Override
        public void initialize(){
            super.initialize();
            shouldFail = this.userName.equalsIgnoreCase(System.getProperty("user.name"));
        }
    }

    /**
     * Fires on every call if the local host name, looked up once in {@link #initialize()}, equals the configured
     * name (case-insensitive).
     */
    @Data
    public static class HostNameTrigger extends FailureTrigger{

        private final String hostName;
        private boolean shouldFail = false;

        public HostNameTrigger(@NonNull String hostName) {
            this.hostName = hostName;
        }

        @Override
        public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
            return shouldFail;
        }

        /**
         * @throws RuntimeException wrapping the {@link UnknownHostException} if the local host name cannot be
         *                          resolved
         */
        @Override
        public void initialize(){
            super.initialize();
            try {
                String hostname = InetAddress.getLocalHost().getHostName();
                log.info("FailureTestingListener hostname: {}", hostname);
                shouldFail = this.hostName.equalsIgnoreCase(hostname);
            } catch (UnknownHostException e){
                throw new RuntimeException("Unable to determine local hostname", e);
            }
        }
    }

    /**
     * Fires on every call (of any type) made while the epoch count ({@code isEpoch == true}) or the iteration
     * count ({@code isEpoch == false}) equals {@code count} exactly.
     */
    @Data
    public static class IterationEpochTrigger extends FailureTrigger {

        private final boolean isEpoch;
        private final int count;

        public IterationEpochTrigger(boolean isEpoch, int count){
            this.isEpoch = isEpoch;
            this.count = count;
        }

        @Override
        public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
            return (isEpoch && epoch == count) || (!isEpoch && iteration == count);
        }
    }
}