MLN/CG: Don't swallow exceptions if a second exception occurs during workspace closing (#161)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-08-24 17:33:11 +10:00 committed by GitHub
parent f8364997c0
commit b85238a6df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 121 additions and 63 deletions

View File

@ -2278,6 +2278,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null; LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null;
List<MemoryWorkspace>[] closeAtEndIteraton = (List<MemoryWorkspace>[])new List[topologicalOrder.length]; List<MemoryWorkspace>[] closeAtEndIteraton = (List<MemoryWorkspace>[])new List[topologicalOrder.length];
MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace(); MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
Throwable t = null;
try { try {
for (int i = 0; i <= stopIndex; i++) { for (int i = 0; i <= stopIndex; i++) {
GraphVertex current = vertices[topologicalOrder[i]]; GraphVertex current = vertices[topologicalOrder[i]];
@ -2302,14 +2303,14 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
.with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG) .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
.build(); .build();
if(detachedInputs){ if (detachedInputs) {
//Sometimes (like: external errors use cases) we don't want the activations/inputs to be //Sometimes (like: external errors use cases) we don't want the activations/inputs to be
// in a workspace // in a workspace
workspaceMgr.setScopedOutFor(ArrayType.INPUT); workspaceMgr.setScopedOutFor(ArrayType.INPUT);
workspaceMgr.setScopedOutFor(ArrayType.ACTIVATIONS); workspaceMgr.setScopedOutFor(ArrayType.ACTIVATIONS);
} else { } else {
//Don't leverage out of async MultiDataSetIterator workspaces //Don't leverage out of async MultiDataSetIterator workspaces
if(features[0].isAttached()){ if (features[0].isAttached()) {
workspaceMgr.setNoLeverageOverride(features[0].data().getParentWorkspace().getId()); workspaceMgr.setNoLeverageOverride(features[0].data().getParentWorkspace().getId());
} }
} }
@ -2326,7 +2327,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
if (ArrayUtils.contains(layerIndexes, vIdx)) { if (ArrayUtils.contains(layerIndexes, vIdx)) {
isRequiredOutput = true; isRequiredOutput = true;
if(outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)){ if (outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)) {
//Place activations in user-specified workspace //Place activations in user-specified workspace
origWSAct = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS); origWSAct = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
origWSActConf = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS); origWSActConf = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
@ -2345,7 +2346,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
//Open the relevant workspace for the activations. //Open the relevant workspace for the activations.
//Note that this will be closed only once the current vertex's activations have been consumed //Note that this will be closed only once the current vertex's activations have been consumed
MemoryWorkspace wsActivations = null; MemoryWorkspace wsActivations = null;
if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || !isRequiredOutput ){ //Open WS if (a) no external/output WS (if present, it's already open), or (b) not being placed in external/output WS if (outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || !isRequiredOutput) { //Open WS if (a) no external/output WS (if present, it's already open), or (b) not being placed in external/output WS
wsActivations = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS); wsActivations = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS);
openActivationsWorkspaces.put(wsActivations, workspaceMgr); openActivationsWorkspaces.put(wsActivations, workspaceMgr);
} }
@ -2353,11 +2354,11 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
//Note that because we're opening activation workspaces not in any defined order (i.e., workspace //Note that because we're opening activation workspaces not in any defined order (i.e., workspace
// use isn't simply nested), we'll manually override the previous workspace setting. Otherwise, when we // use isn't simply nested), we'll manually override the previous workspace setting. Otherwise, when we
// close these workspaces, the "current" workspace may be set to the incorrect one // close these workspaces, the "current" workspace may be set to the incorrect one
if(wsActivations != null ) if (wsActivations != null)
wsActivations.setPreviousWorkspace(initialWorkspace); wsActivations.setPreviousWorkspace(initialWorkspace);
int closeableAt = vertexOutputsFullyConsumedByStep[vIdx]; int closeableAt = vertexOutputsFullyConsumedByStep[vIdx];
if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || (wsActivations != null && !outputWorkspace.getId().equals(wsActivations.getId()))) { if (outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || (wsActivations != null && !outputWorkspace.getId().equals(wsActivations.getId()))) {
if (closeAtEndIteraton[closeableAt] == null) { if (closeAtEndIteraton[closeableAt] == null) {
closeAtEndIteraton[closeableAt] = new ArrayList<>(); closeAtEndIteraton[closeableAt] = new ArrayList<>();
} }
@ -2373,18 +2374,18 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
out = features[vIdx]; out = features[vIdx];
} else { } else {
if(fwdPassType == FwdPassType.STANDARD){ if (fwdPassType == FwdPassType.STANDARD) {
//Standard feed-forward case //Standard feed-forward case
out = current.doForward(train, workspaceMgr); out = current.doForward(train, workspaceMgr);
} else if(fwdPassType == FwdPassType.RNN_TIMESTEP){ } else if (fwdPassType == FwdPassType.RNN_TIMESTEP) {
if (current.hasLayer()) { if (current.hasLayer()) {
//Layer //Layer
INDArray input = current.getInputs()[0]; INDArray input = current.getInputs()[0];
Layer l = current.getLayer(); Layer l = current.getLayer();
if (l instanceof RecurrentLayer) { if (l instanceof RecurrentLayer) {
out = ((RecurrentLayer) l).rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr); out = ((RecurrentLayer) l).rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr);
} else if(l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer && ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying() instanceof RecurrentLayer){ } else if (l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer && ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) l).getUnderlying() instanceof RecurrentLayer) {
RecurrentLayer rl = ((RecurrentLayer) ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying()); RecurrentLayer rl = ((RecurrentLayer) ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) l).getUnderlying());
out = rl.rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr); out = rl.rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr);
} else if (l instanceof MultiLayerNetwork) { } else if (l instanceof MultiLayerNetwork) {
out = ((MultiLayerNetwork) l).rnnTimeStep(reshapeTimeStepInput(input)); out = ((MultiLayerNetwork) l).rnnTimeStep(reshapeTimeStepInput(input));
@ -2402,7 +2403,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
validateArrayWorkspaces(workspaceMgr, out, ArrayType.ACTIVATIONS, vName, false, "Feed forward (inference)"); validateArrayWorkspaces(workspaceMgr, out, ArrayType.ACTIVATIONS, vName, false, "Feed forward (inference)");
} }
if(inputsTo != null) { //Output vertices may not input to any other vertices if (inputsTo != null) { //Output vertices may not input to any other vertices
for (VertexIndices v : inputsTo) { for (VertexIndices v : inputsTo) {
//Note that we don't have to do anything special here: the activations are always detached in //Note that we don't have to do anything special here: the activations are always detached in
// this method // this method
@ -2412,13 +2413,13 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
} }
} }
if(clearLayerInputs) { if (clearLayerInputs) {
current.clear(); current.clear();
} }
if(isRequiredOutput){ if (isRequiredOutput) {
outputs[ArrayUtils.indexOf(layerIndexes, vIdx)] = out; outputs[ArrayUtils.indexOf(layerIndexes, vIdx)] = out;
if(origWSAct != null){ if (origWSAct != null) {
//Reset the configuration, as we may reuse this workspace manager... //Reset the configuration, as we may reuse this workspace manager...
workspaceMgr.setWorkspace(ArrayType.ACTIVATIONS, origWSAct, origWSActConf); workspaceMgr.setWorkspace(ArrayType.ACTIVATIONS, origWSAct, origWSActConf);
} }
@ -2428,14 +2429,16 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
//Close any activations workspaces that we no longer require //Close any activations workspaces that we no longer require
//Note that activations workspaces can be closed only once the corresponding output activations have //Note that activations workspaces can be closed only once the corresponding output activations have
// been fully consumed // been fully consumed
if(closeAtEndIteraton[i] != null){ if (closeAtEndIteraton[i] != null) {
for(MemoryWorkspace wsAct : closeAtEndIteraton[i]){ for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
wsAct.close(); wsAct.close();
LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct); LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
freeWorkspaceManagers.add(canNowReuse); freeWorkspaceManagers.add(canNowReuse);
} }
} }
} }
} catch (Throwable t2){
t = t2;
} finally { } finally {
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown //Close all open workspaces... usually this list will be empty, but not if an exception is thrown
//Though if stopIndex < numLayers, some might still be open //Though if stopIndex < numLayers, some might still be open
@ -2444,7 +2447,15 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
//Edge case here: seems that scoping out can increase the tagScope of the current WS //Edge case here: seems that scoping out can increase the tagScope of the current WS
//and if we hit an exception during forward pass, we aren't guaranteed to call close a sufficient //and if we hit an exception during forward pass, we aren't guaranteed to call close a sufficient
// number of times to actually close it, in all cases // number of times to actually close it, in all cases
ws.close(); try{
ws.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
} }
} }
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
@ -2581,28 +2592,29 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
boolean traceLog = log.isTraceEnabled(); boolean traceLog = log.isTraceEnabled();
try{ Throwable t = null;
for(int i=topologicalOrder.length-1; i>= 0; i--){ try {
for (int i = topologicalOrder.length - 1; i >= 0; i--) {
boolean hitFrozen = false; boolean hitFrozen = false;
GraphVertex current = vertices[topologicalOrder[i]]; GraphVertex current = vertices[topologicalOrder[i]];
int vIdx = current.getVertexIndex(); int vIdx = current.getVertexIndex();
String vertexName = current.getVertexName(); String vertexName = current.getVertexName();
if(traceLog){ if (traceLog) {
log.trace("About backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName()); log.trace("About backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName());
} }
//FIXME: make the frozen vertex feature extraction more flexible //FIXME: make the frozen vertex feature extraction more flexible
if (current.hasLayer() && current.getLayer() instanceof FrozenLayer || current instanceof FrozenVertex){ if (current.hasLayer() && current.getLayer() instanceof FrozenLayer || current instanceof FrozenVertex) {
hitFrozen = true; hitFrozen = true;
} }
if (current.isInputVertex() || hitFrozen){ if (current.isInputVertex() || hitFrozen) {
//Close any activation gradient workspaces that we no longer require //Close any activation gradient workspaces that we no longer require
//Note that activation gradient workspaces can be closed only once the corresponding activations //Note that activation gradient workspaces can be closed only once the corresponding activations
// gradients have been fully consumed // gradients have been fully consumed
if(closeAtEndIteraton[i] != null){ if (closeAtEndIteraton[i] != null) {
for(MemoryWorkspace wsAct : closeAtEndIteraton[i]){ for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
wsAct.close(); wsAct.close();
LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct); LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
freeWorkspaceManagers.add(canNowReuse); freeWorkspaceManagers.add(canNowReuse);
@ -2680,7 +2692,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
wsActivationGrads.setPreviousWorkspace(initialWorkspace); wsActivationGrads.setPreviousWorkspace(initialWorkspace);
int closeableAt = vertexActGradsFullyConsumedByStep[vIdx]; int closeableAt = vertexActGradsFullyConsumedByStep[vIdx];
if(closeableAt >= 0) { if (closeableAt >= 0) {
if (closeAtEndIteraton[closeableAt] == null) { if (closeAtEndIteraton[closeableAt] == null) {
closeAtEndIteraton[closeableAt] = new ArrayList<>(); closeAtEndIteraton[closeableAt] = new ArrayList<>();
} }
@ -2689,14 +2701,14 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
Pair<Gradient, INDArray[]> pair; Pair<Gradient, INDArray[]> pair;
INDArray[] epsilons; INDArray[] epsilons;
try(MemoryWorkspace wsWorkingMem = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)){ try (MemoryWorkspace wsWorkingMem = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)) {
pair = current.doBackward(truncatedBPTT, workspaceMgr); pair = current.doBackward(truncatedBPTT, workspaceMgr);
epsilons = pair.getSecond(); epsilons = pair.getSecond();
//Validate workspace location for the activation gradients: //Validate workspace location for the activation gradients:
//validateArrayWorkspaces(LayerWorkspaceMgr mgr, INDArray array, ArrayType arrayType, String vertexName, boolean isInputVertex, String op){ //validateArrayWorkspaces(LayerWorkspaceMgr mgr, INDArray array, ArrayType arrayType, String vertexName, boolean isInputVertex, String op){
for (INDArray epsilon : epsilons) { for (INDArray epsilon : epsilons) {
if(epsilon != null) { if (epsilon != null) {
//May be null for EmbeddingLayer, etc //May be null for EmbeddingLayer, etc
validateArrayWorkspaces(workspaceMgr, epsilon, ArrayType.ACTIVATION_GRAD, vertexName, false, "Backprop"); validateArrayWorkspaces(workspaceMgr, epsilon, ArrayType.ACTIVATION_GRAD, vertexName, false, "Backprop");
} }
@ -2732,15 +2744,15 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
tempList.addFirst(new Triple<>(newName, entry.getValue(), tempList.addFirst(new Triple<>(newName, entry.getValue(),
g.flatteningOrderForVariable(origName))); g.flatteningOrderForVariable(origName)));
} }
for (Triple<String, INDArray, Character> t : tempList) for (Triple<String, INDArray, Character> triple : tempList)
gradients.addFirst(t); gradients.addFirst(triple);
} }
//Close any activation gradient workspaces that we no longer require //Close any activation gradient workspaces that we no longer require
//Note that activation gradient workspaces can be closed only once the corresponding activations //Note that activation gradient workspaces can be closed only once the corresponding activations
// gradients have been fully consumed // gradients have been fully consumed
if(closeAtEndIteraton[i] != null){ if (closeAtEndIteraton[i] != null) {
for(MemoryWorkspace wsAct : closeAtEndIteraton[i]){ for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
wsAct.close(); wsAct.close();
LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct); LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
freeWorkspaceManagers.add(canNowReuse); freeWorkspaceManagers.add(canNowReuse);
@ -2748,23 +2760,32 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
closeAtEndIteraton[i] = null; closeAtEndIteraton[i] = null;
} }
if(traceLog){ if (traceLog) {
log.trace("Completed backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName()); log.trace("Completed backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName());
} }
} }
} catch (Throwable t2){
t = t2;
} finally { } finally {
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown //Close all open workspaces... usually this list will be empty, but not if an exception is thrown
for(MemoryWorkspace ws : openActivationsWorkspaces.keySet()){ for(MemoryWorkspace ws : openActivationsWorkspaces.keySet()){
ws.close(); try{
ws.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
} }
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
} }
//Now, add the gradients in the order we need them in for flattening (same as params order) //Now, add the gradients in the order we need them in for flattening (same as params order)
Gradient gradient = new DefaultGradient(flattenedGradients); Gradient gradient = new DefaultGradient(flattenedGradients);
for (Triple<String, INDArray, Character> t : gradients) { for (Triple<String, INDArray, Character> tr : gradients) {
gradient.setGradientFor(t.getFirst(), t.getSecond(), t.getThird()); gradient.setGradientFor(tr.getFirst(), tr.getSecond(), tr.getThird());
} }
this.gradient = gradient; this.gradient = gradient;

View File

@ -1242,17 +1242,18 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
boolean traceLog = log.isTraceEnabled(); boolean traceLog = log.isTraceEnabled();
Throwable t = null;
try { try {
for (int i = 0; i <= layerIndex; i++) { for (int i = 0; i <= layerIndex; i++) {
LayerWorkspaceMgr mgr = (i % 2 == 0 ? mgrEven : mgrOdd); LayerWorkspaceMgr mgr = (i % 2 == 0 ? mgrEven : mgrOdd);
if(traceLog){ if (traceLog) {
log.trace("About to forward pass: {} - {}", i, layers[i].getClass().getSimpleName()); log.trace("About to forward pass: {} - {}", i, layers[i].getClass().getSimpleName());
} }
//Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet) //Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet)
//Hence: put inputs in working memory //Hence: put inputs in working memory
if(i == 0 && wsm != WorkspaceMode.NONE){ if (i == 0 && wsm != WorkspaceMode.NONE) {
mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG); mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG);
} }
@ -1268,7 +1269,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
temp.setPreviousWorkspace(initialWorkspace); temp.setPreviousWorkspace(initialWorkspace);
if(i == 0 && input.isAttached()){ if (i == 0 && input.isAttached()) {
//Don't leverage out of async DataSetIterator workspaces //Don't leverage out of async DataSetIterator workspaces
mgr.setNoLeverageOverride(input.data().getParentWorkspace().getId()); mgr.setNoLeverageOverride(input.data().getParentWorkspace().getId());
} }
@ -1279,8 +1280,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, true, "Output of layer (inference)"); validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, true, "Output of layer (inference)");
} }
if ( i == layerIndex ) { if (i == layerIndex) {
if(outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)){ if (outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)) {
//Place activations in user-specified workspace //Place activations in user-specified workspace
mgr.setWorkspace(ArrayType.ACTIVATIONS, outputWorkspace.getId(), outputWorkspace.getWorkspaceConfiguration()); mgr.setWorkspace(ArrayType.ACTIVATIONS, outputWorkspace.getId(), outputWorkspace.getWorkspaceConfiguration());
} else { } else {
@ -1289,15 +1290,15 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
} }
} }
if(fwdPassType == FwdPassType.STANDARD){ if (fwdPassType == FwdPassType.STANDARD) {
//Standard feed-forward case //Standard feed-forward case
input = layers[i].activate(input, train, mgr); input = layers[i].activate(input, train, mgr);
} else if(fwdPassType == FwdPassType.RNN_TIMESTEP){ } else if (fwdPassType == FwdPassType.RNN_TIMESTEP) {
//rnnTimeStep case //rnnTimeStep case
if (layers[i] instanceof RecurrentLayer) { if (layers[i] instanceof RecurrentLayer) {
input = ((RecurrentLayer) layers[i]).rnnTimeStep(reshapeTimeStepInput(input), mgr); input = ((RecurrentLayer) layers[i]).rnnTimeStep(reshapeTimeStepInput(input), mgr);
} else if(layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer)layers[i]).getUnderlying() instanceof RecurrentLayer){ } else if (layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer) layers[i]).getUnderlying() instanceof RecurrentLayer) {
RecurrentLayer rl = ((RecurrentLayer) ((BaseWrapperLayer)layers[i]).getUnderlying()); RecurrentLayer rl = ((RecurrentLayer) ((BaseWrapperLayer) layers[i]).getUnderlying());
input = rl.rnnTimeStep(reshapeTimeStepInput(input), mgr); input = rl.rnnTimeStep(reshapeTimeStepInput(input), mgr);
} else if (layers[i] instanceof MultiLayerNetwork) { } else if (layers[i] instanceof MultiLayerNetwork) {
input = ((MultiLayerNetwork) layers[i]).rnnTimeStep(reshapeTimeStepInput(input)); input = ((MultiLayerNetwork) layers[i]).rnnTimeStep(reshapeTimeStepInput(input));
@ -1311,34 +1312,51 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
//Validation: Exception if invalid (bad layer implementation) //Validation: Exception if invalid (bad layer implementation)
validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, false, "Output of layer (inference)"); validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, false, "Output of layer (inference)");
if(wsActCloseNext != null){ if (wsActCloseNext != null) {
wsActCloseNext.close(); wsActCloseNext.close();
} }
wsActCloseNext = temp; wsActCloseNext = temp;
temp = null; temp = null;
} }
if(traceLog){ if (traceLog) {
log.trace("Completed forward pass: {} - {}", i, layers[i].getClass().getSimpleName()); log.trace("Completed forward pass: {} - {}", i, layers[i].getClass().getSimpleName());
} }
//Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet) //Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet)
//Hence: put inputs in working memory -> set back to default for next use of workspace mgr //Hence: put inputs in working memory -> set back to default for next use of workspace mgr
if(i == 0 && wsm != WorkspaceMode.NONE){ if (i == 0 && wsm != WorkspaceMode.NONE) {
mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG); //Inputs should always be in the previous WS mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG); //Inputs should always be in the previous WS
} }
} }
} catch (Throwable t2){
t = t2;
} finally { } finally {
if(wsActCloseNext != null){ if(wsActCloseNext != null){
wsActCloseNext.close(); try {
wsActCloseNext.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
} }
if(temp != null){ if(temp != null){
//Should only be non-null on exception //Should only be non-null on exception
while(temp.isScopeActive()){ while(temp.isScopeActive()){
//For safety, should never occur in theory: a single close() call may not be sufficient, if //For safety, should never occur in theory: a single close() call may not be sufficient, if
// workspace scope was borrowed and not properly closed when exception occurred // workspace scope was borrowed and not properly closed when exception occurred
temp.close(); try{
temp.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
} }
} }
@ -1871,13 +1889,14 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
boolean traceLog = log.isTraceEnabled(); boolean traceLog = log.isTraceEnabled();
Throwable t = null;
try { try {
for (int i = layers.length - 1; i >= 0; i--) { for (int i = layers.length - 1; i >= 0; i--) {
if (layers[i] instanceof FrozenLayer) { if (layers[i] instanceof FrozenLayer) {
break; break;
} }
if(traceLog){ if (traceLog) {
log.trace("About to backprop: {} - {}", i, layers[i].getClass().getSimpleName()); log.trace("About to backprop: {} - {}", i, layers[i].getClass().getSimpleName());
} }
@ -1897,7 +1916,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
//Open activation gradients WS *then* BP working memory, so BP working memory is opened last for use in layers //Open activation gradients WS *then* BP working memory, so BP working memory is opened last for use in layers
wsActGradTemp = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATION_GRAD); wsActGradTemp = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATION_GRAD);
try(MemoryWorkspace wsBPWorking = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)){ try (MemoryWorkspace wsBPWorking = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)) {
//Note that because we're opening activation workspaces not in a simple nested order, we'll manually //Note that because we're opening activation workspaces not in a simple nested order, we'll manually
// override the previous workspace setting. Otherwise, when we close these workspaces, the "current" // override the previous workspace setting. Otherwise, when we close these workspaces, the "current"
@ -1907,7 +1926,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
INDArray eps = (i == layers.length - 1 ? epsilon : currPair.getRight()); //eps is null for OutputLayer INDArray eps = (i == layers.length - 1 ? epsilon : currPair.getRight()); //eps is null for OutputLayer
if(!tbptt){ if (!tbptt) {
//Standard case //Standard case
currPair = layers[i].backpropGradient(eps, workspaceMgr); currPair = layers[i].backpropGradient(eps, workspaceMgr);
} else { } else {
@ -1920,7 +1939,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
} }
} }
if(currPair.getSecond() != null) { if (currPair.getSecond() != null) {
//Edge case: may be null for Embedding layer, for example //Edge case: may be null for Embedding layer, for example
validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i, validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i,
false, "Backprop"); false, "Backprop");
@ -1936,38 +1955,56 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
currPair = new Pair<>(currPair.getFirst(), currPair = new Pair<>(currPair.getFirst(),
this.layerWiseConfigurations.getInputPreProcess(i) this.layerWiseConfigurations.getInputPreProcess(i)
.backprop(currPair.getSecond(), getInputMiniBatchSize(), workspaceMgr)); .backprop(currPair.getSecond(), getInputMiniBatchSize(), workspaceMgr));
if (i > 0 && currPair.getSecond() != null){ if (i > 0 && currPair.getSecond() != null) {
validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i, validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i,
true, "Backprop"); true, "Backprop");
} }
} }
if(i == 0 ){ if (i == 0) {
if(returnInputActGrad && currPair.getSecond() != null){ if (returnInputActGrad && currPair.getSecond() != null) {
currPair.setSecond(currPair.getSecond().detach()); currPair.setSecond(currPair.getSecond().detach());
} else { } else {
currPair.setSecond(null); currPair.setSecond(null);
} }
} }
if(wsActGradCloseNext != null){ if (wsActGradCloseNext != null) {
wsActGradCloseNext.close(); wsActGradCloseNext.close();
} }
wsActGradCloseNext = wsActGradTemp; wsActGradCloseNext = wsActGradTemp;
wsActGradTemp = null; wsActGradTemp = null;
} }
if(traceLog){ if (traceLog) {
log.trace("Completed backprop: {} - {}", i, layers[i].getClass().getSimpleName()); log.trace("Completed backprop: {} - {}", i, layers[i].getClass().getSimpleName());
} }
} }
} catch (Throwable thr ){
t = thr;
} finally { } finally {
if(wsActGradCloseNext != null){ if(wsActGradCloseNext != null){
wsActGradCloseNext.close(); try {
wsActGradCloseNext.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
} }
if(wsActGradTemp != null){ if(wsActGradTemp != null) {
//Should only be non-null on exception //Should only be non-null on exception
wsActGradTemp.close(); try {
wsActGradTemp.close();
} catch (Throwable t2) {
if (t != null) {
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
} }
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
} }