MLN/CG: Don't swallow exceptions if a second exception occurs during workspace closing (#161)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
master
parent
f8364997c0
commit
b85238a6df
|
@ -2278,6 +2278,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null;
|
LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null;
|
||||||
List<MemoryWorkspace>[] closeAtEndIteraton = (List<MemoryWorkspace>[])new List[topologicalOrder.length];
|
List<MemoryWorkspace>[] closeAtEndIteraton = (List<MemoryWorkspace>[])new List[topologicalOrder.length];
|
||||||
MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
|
MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
|
||||||
|
Throwable t = null;
|
||||||
try {
|
try {
|
||||||
for (int i = 0; i <= stopIndex; i++) {
|
for (int i = 0; i <= stopIndex; i++) {
|
||||||
GraphVertex current = vertices[topologicalOrder[i]];
|
GraphVertex current = vertices[topologicalOrder[i]];
|
||||||
|
@ -2436,6 +2437,8 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} catch (Throwable t2){
|
||||||
|
t = t2;
|
||||||
} finally {
|
} finally {
|
||||||
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown
|
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown
|
||||||
//Though if stopIndex < numLayers, some might still be open
|
//Though if stopIndex < numLayers, some might still be open
|
||||||
|
@ -2444,7 +2447,15 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
//Edge case here: seems that scoping out can increase the tagScope of the current WS
|
//Edge case here: seems that scoping out can increase the tagScope of the current WS
|
||||||
//and if we hit an exception during forward pass, we aren't guaranteed to call close a sufficient
|
//and if we hit an exception during forward pass, we aren't guaranteed to call close a sufficient
|
||||||
// number of times to actually close it, in all cases
|
// number of times to actually close it, in all cases
|
||||||
|
try{
|
||||||
ws.close();
|
ws.close();
|
||||||
|
} catch (Throwable t2){
|
||||||
|
if(t != null){
|
||||||
|
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||||
|
log.error("Original exception:", t);
|
||||||
|
throw t2;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
|
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
|
||||||
|
@ -2581,6 +2592,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
|
|
||||||
boolean traceLog = log.isTraceEnabled();
|
boolean traceLog = log.isTraceEnabled();
|
||||||
|
|
||||||
|
Throwable t = null;
|
||||||
try {
|
try {
|
||||||
for (int i = topologicalOrder.length - 1; i >= 0; i--) {
|
for (int i = topologicalOrder.length - 1; i >= 0; i--) {
|
||||||
boolean hitFrozen = false;
|
boolean hitFrozen = false;
|
||||||
|
@ -2732,8 +2744,8 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
tempList.addFirst(new Triple<>(newName, entry.getValue(),
|
tempList.addFirst(new Triple<>(newName, entry.getValue(),
|
||||||
g.flatteningOrderForVariable(origName)));
|
g.flatteningOrderForVariable(origName)));
|
||||||
}
|
}
|
||||||
for (Triple<String, INDArray, Character> t : tempList)
|
for (Triple<String, INDArray, Character> triple : tempList)
|
||||||
gradients.addFirst(t);
|
gradients.addFirst(triple);
|
||||||
}
|
}
|
||||||
|
|
||||||
//Close any activation gradient workspaces that we no longer require
|
//Close any activation gradient workspaces that we no longer require
|
||||||
|
@ -2752,19 +2764,28 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
||||||
log.trace("Completed backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName());
|
log.trace("Completed backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} catch (Throwable t2){
|
||||||
|
t = t2;
|
||||||
} finally {
|
} finally {
|
||||||
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown
|
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown
|
||||||
for(MemoryWorkspace ws : openActivationsWorkspaces.keySet()){
|
for(MemoryWorkspace ws : openActivationsWorkspaces.keySet()){
|
||||||
|
try{
|
||||||
ws.close();
|
ws.close();
|
||||||
|
} catch (Throwable t2){
|
||||||
|
if(t != null){
|
||||||
|
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||||
|
log.error("Original exception:", t);
|
||||||
|
throw t2;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
|
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
|
||||||
}
|
}
|
||||||
|
|
||||||
//Now, add the gradients in the order we need them in for flattening (same as params order)
|
//Now, add the gradients in the order we need them in for flattening (same as params order)
|
||||||
Gradient gradient = new DefaultGradient(flattenedGradients);
|
Gradient gradient = new DefaultGradient(flattenedGradients);
|
||||||
for (Triple<String, INDArray, Character> t : gradients) {
|
for (Triple<String, INDArray, Character> tr : gradients) {
|
||||||
gradient.setGradientFor(t.getFirst(), t.getSecond(), t.getThird());
|
gradient.setGradientFor(tr.getFirst(), tr.getSecond(), tr.getThird());
|
||||||
}
|
}
|
||||||
|
|
||||||
this.gradient = gradient;
|
this.gradient = gradient;
|
||||||
|
|
|
@ -1242,6 +1242,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
||||||
|
|
||||||
boolean traceLog = log.isTraceEnabled();
|
boolean traceLog = log.isTraceEnabled();
|
||||||
|
|
||||||
|
Throwable t = null;
|
||||||
try {
|
try {
|
||||||
for (int i = 0; i <= layerIndex; i++) {
|
for (int i = 0; i <= layerIndex; i++) {
|
||||||
LayerWorkspaceMgr mgr = (i % 2 == 0 ? mgrEven : mgrOdd);
|
LayerWorkspaceMgr mgr = (i % 2 == 0 ? mgrEven : mgrOdd);
|
||||||
|
@ -1328,17 +1329,34 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
||||||
mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG); //Inputs should always be in the previous WS
|
mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG); //Inputs should always be in the previous WS
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} catch (Throwable t2){
|
||||||
|
t = t2;
|
||||||
} finally {
|
} finally {
|
||||||
if(wsActCloseNext != null){
|
if(wsActCloseNext != null){
|
||||||
|
try {
|
||||||
wsActCloseNext.close();
|
wsActCloseNext.close();
|
||||||
|
} catch (Throwable t2){
|
||||||
|
if(t != null){
|
||||||
|
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||||
|
log.error("Original exception:", t);
|
||||||
|
throw t2;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if(temp != null){
|
if(temp != null){
|
||||||
//Should only be non-null on exception
|
//Should only be non-null on exception
|
||||||
while(temp.isScopeActive()){
|
while(temp.isScopeActive()){
|
||||||
//For safety, should never occur in theory: a single close() call may not be sufficient, if
|
//For safety, should never occur in theory: a single close() call may not be sufficient, if
|
||||||
// workspace scope was borrowed and not properly closed when exception occurred
|
// workspace scope was borrowed and not properly closed when exception occurred
|
||||||
|
try{
|
||||||
temp.close();
|
temp.close();
|
||||||
|
} catch (Throwable t2){
|
||||||
|
if(t != null){
|
||||||
|
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||||
|
log.error("Original exception:", t);
|
||||||
|
throw t2;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1871,6 +1889,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
||||||
|
|
||||||
boolean traceLog = log.isTraceEnabled();
|
boolean traceLog = log.isTraceEnabled();
|
||||||
|
|
||||||
|
Throwable t = null;
|
||||||
try {
|
try {
|
||||||
for (int i = layers.length - 1; i >= 0; i--) {
|
for (int i = layers.length - 1; i >= 0; i--) {
|
||||||
if (layers[i] instanceof FrozenLayer) {
|
if (layers[i] instanceof FrozenLayer) {
|
||||||
|
@ -1961,13 +1980,31 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
||||||
log.trace("Completed backprop: {} - {}", i, layers[i].getClass().getSimpleName());
|
log.trace("Completed backprop: {} - {}", i, layers[i].getClass().getSimpleName());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} catch (Throwable thr ){
|
||||||
|
t = thr;
|
||||||
} finally {
|
} finally {
|
||||||
if(wsActGradCloseNext != null){
|
if(wsActGradCloseNext != null){
|
||||||
|
try {
|
||||||
wsActGradCloseNext.close();
|
wsActGradCloseNext.close();
|
||||||
|
} catch (Throwable t2){
|
||||||
|
if(t != null){
|
||||||
|
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||||
|
log.error("Original exception:", t);
|
||||||
|
throw t2;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if(wsActGradTemp != null) {
|
if(wsActGradTemp != null) {
|
||||||
//Should only be non-null on exception
|
//Should only be non-null on exception
|
||||||
|
try {
|
||||||
wsActGradTemp.close();
|
wsActGradTemp.close();
|
||||||
|
} catch (Throwable t2) {
|
||||||
|
if (t != null) {
|
||||||
|
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||||
|
log.error("Original exception:", t);
|
||||||
|
throw t2;
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
|
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue