MLN/CG: Don't swallow exceptions if a second exception occurs during workspace closing (#161)

Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
Alex Black 2019-08-24 17:33:11 +10:00 committed by GitHub
parent f8364997c0
commit b85238a6df
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 121 additions and 63 deletions

View File

@ -2278,6 +2278,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null; LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null;
List<MemoryWorkspace>[] closeAtEndIteraton = (List<MemoryWorkspace>[])new List[topologicalOrder.length]; List<MemoryWorkspace>[] closeAtEndIteraton = (List<MemoryWorkspace>[])new List[topologicalOrder.length];
MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace(); MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
Throwable t = null;
try { try {
for (int i = 0; i <= stopIndex; i++) { for (int i = 0; i <= stopIndex; i++) {
GraphVertex current = vertices[topologicalOrder[i]]; GraphVertex current = vertices[topologicalOrder[i]];
@ -2436,6 +2437,8 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
} }
} }
} }
} catch (Throwable t2){
t = t2;
} finally { } finally {
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown //Close all open workspaces... usually this list will be empty, but not if an exception is thrown
//Though if stopIndex < numLayers, some might still be open //Though if stopIndex < numLayers, some might still be open
@ -2444,7 +2447,15 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
//Edge case here: seems that scoping out can increase the tagScope of the current WS //Edge case here: seems that scoping out can increase the tagScope of the current WS
//and if we hit an exception during forward pass, we aren't guaranteed to call close a sufficient //and if we hit an exception during forward pass, we aren't guaranteed to call close a sufficient
// number of times to actually close it, in all cases // number of times to actually close it, in all cases
try{
ws.close(); ws.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
} }
} }
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
@ -2581,6 +2592,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
boolean traceLog = log.isTraceEnabled(); boolean traceLog = log.isTraceEnabled();
Throwable t = null;
try { try {
for (int i = topologicalOrder.length - 1; i >= 0; i--) { for (int i = topologicalOrder.length - 1; i >= 0; i--) {
boolean hitFrozen = false; boolean hitFrozen = false;
@ -2732,8 +2744,8 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
tempList.addFirst(new Triple<>(newName, entry.getValue(), tempList.addFirst(new Triple<>(newName, entry.getValue(),
g.flatteningOrderForVariable(origName))); g.flatteningOrderForVariable(origName)));
} }
for (Triple<String, INDArray, Character> t : tempList) for (Triple<String, INDArray, Character> triple : tempList)
gradients.addFirst(t); gradients.addFirst(triple);
} }
//Close any activation gradient workspaces that we no longer require //Close any activation gradient workspaces that we no longer require
@ -2752,19 +2764,28 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
log.trace("Completed backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName()); log.trace("Completed backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName());
} }
} }
} catch (Throwable t2){
t = t2;
} finally { } finally {
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown //Close all open workspaces... usually this list will be empty, but not if an exception is thrown
for(MemoryWorkspace ws : openActivationsWorkspaces.keySet()){ for(MemoryWorkspace ws : openActivationsWorkspaces.keySet()){
try{
ws.close(); ws.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
} }
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
} }
//Now, add the gradients in the order we need them in for flattening (same as params order) //Now, add the gradients in the order we need them in for flattening (same as params order)
Gradient gradient = new DefaultGradient(flattenedGradients); Gradient gradient = new DefaultGradient(flattenedGradients);
for (Triple<String, INDArray, Character> t : gradients) { for (Triple<String, INDArray, Character> tr : gradients) {
gradient.setGradientFor(t.getFirst(), t.getSecond(), t.getThird()); gradient.setGradientFor(tr.getFirst(), tr.getSecond(), tr.getThird());
} }
this.gradient = gradient; this.gradient = gradient;

View File

@ -1242,6 +1242,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
boolean traceLog = log.isTraceEnabled(); boolean traceLog = log.isTraceEnabled();
Throwable t = null;
try { try {
for (int i = 0; i <= layerIndex; i++) { for (int i = 0; i <= layerIndex; i++) {
LayerWorkspaceMgr mgr = (i % 2 == 0 ? mgrEven : mgrOdd); LayerWorkspaceMgr mgr = (i % 2 == 0 ? mgrEven : mgrOdd);
@ -1328,17 +1329,34 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG); //Inputs should always be in the previous WS mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG); //Inputs should always be in the previous WS
} }
} }
} catch (Throwable t2){
t = t2;
} finally { } finally {
if(wsActCloseNext != null){ if(wsActCloseNext != null){
try {
wsActCloseNext.close(); wsActCloseNext.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
} }
if(temp != null){ if(temp != null){
//Should only be non-null on exception //Should only be non-null on exception
while(temp.isScopeActive()){ while(temp.isScopeActive()){
//For safety, should never occur in theory: a single close() call may not be sufficient, if //For safety, should never occur in theory: a single close() call may not be sufficient, if
// workspace scope was borrowed and not properly closed when exception occurred // workspace scope was borrowed and not properly closed when exception occurred
try{
temp.close(); temp.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
} }
} }
@ -1871,6 +1889,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
boolean traceLog = log.isTraceEnabled(); boolean traceLog = log.isTraceEnabled();
Throwable t = null;
try { try {
for (int i = layers.length - 1; i >= 0; i--) { for (int i = layers.length - 1; i >= 0; i--) {
if (layers[i] instanceof FrozenLayer) { if (layers[i] instanceof FrozenLayer) {
@ -1961,13 +1980,31 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
log.trace("Completed backprop: {} - {}", i, layers[i].getClass().getSimpleName()); log.trace("Completed backprop: {} - {}", i, layers[i].getClass().getSimpleName());
} }
} }
} catch (Throwable thr ){
t = thr;
} finally { } finally {
if(wsActGradCloseNext != null){ if(wsActGradCloseNext != null){
try {
wsActGradCloseNext.close(); wsActGradCloseNext.close();
} catch (Throwable t2){
if(t != null){
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
} }
if(wsActGradTemp != null) { if(wsActGradTemp != null) {
//Should only be non-null on exception //Should only be non-null on exception
try {
wsActGradTemp.close(); wsActGradTemp.close();
} catch (Throwable t2) {
if (t != null) {
log.error("Encountered second exception while trying to close workspace after initial exception");
log.error("Original exception:", t);
throw t2;
}
}
} }
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace); Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
} }