MLN/CG: Don't swallow exceptions if a second exception occurs during workspace closing (#161)
Signed-off-by: AlexDBlack <blacka101@gmail.com>master
parent
f8364997c0
commit
b85238a6df
|
@ -2278,6 +2278,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null;
|
||||
List<MemoryWorkspace>[] closeAtEndIteraton = (List<MemoryWorkspace>[])new List[topologicalOrder.length];
|
||||
MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
|
||||
Throwable t = null;
|
||||
try {
|
||||
for (int i = 0; i <= stopIndex; i++) {
|
||||
GraphVertex current = vertices[topologicalOrder[i]];
|
||||
|
@ -2302,14 +2303,14 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
.with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
|
||||
.build();
|
||||
|
||||
if(detachedInputs){
|
||||
if (detachedInputs) {
|
||||
//Sometimes (like: external errors use cases) we don't want the activations/inputs to be
|
||||
// in a workspace
|
||||
workspaceMgr.setScopedOutFor(ArrayType.INPUT);
|
||||
workspaceMgr.setScopedOutFor(ArrayType.ACTIVATIONS);
|
||||
} else {
|
||||
//Don't leverage out of async MultiDataSetIterator workspaces
|
||||
if(features[0].isAttached()){
|
||||
if (features[0].isAttached()) {
|
||||
workspaceMgr.setNoLeverageOverride(features[0].data().getParentWorkspace().getId());
|
||||
}
|
||||
}
|
||||
|
@ -2326,7 +2327,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
if (ArrayUtils.contains(layerIndexes, vIdx)) {
|
||||
isRequiredOutput = true;
|
||||
|
||||
if(outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)){
|
||||
if (outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)) {
|
||||
//Place activations in user-specified workspace
|
||||
origWSAct = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
|
||||
origWSActConf = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
|
||||
|
@ -2345,7 +2346,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
//Open the relevant workspace for the activations.
|
||||
//Note that this will be closed only once the current vertex's activations have been consumed
|
||||
MemoryWorkspace wsActivations = null;
|
||||
if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || !isRequiredOutput ){ //Open WS if (a) no external/output WS (if present, it's already open), or (b) not being placed in external/output WS
|
||||
if (outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || !isRequiredOutput) { //Open WS if (a) no external/output WS (if present, it's already open), or (b) not being placed in external/output WS
|
||||
wsActivations = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS);
|
||||
openActivationsWorkspaces.put(wsActivations, workspaceMgr);
|
||||
}
|
||||
|
@ -2353,11 +2354,11 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
//Note that because we're opening activation workspaces not in any defined order (i.e., workspace
|
||||
// use isn't simply nested), we'll manually override the previous workspace setting. Otherwise, when we
|
||||
// close these workspaces, the "current" workspace may be set to the incorrect one
|
||||
if(wsActivations != null )
|
||||
if (wsActivations != null)
|
||||
wsActivations.setPreviousWorkspace(initialWorkspace);
|
||||
|
||||
int closeableAt = vertexOutputsFullyConsumedByStep[vIdx];
|
||||
if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || (wsActivations != null && !outputWorkspace.getId().equals(wsActivations.getId()))) {
|
||||
if (outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || (wsActivations != null && !outputWorkspace.getId().equals(wsActivations.getId()))) {
|
||||
if (closeAtEndIteraton[closeableAt] == null) {
|
||||
closeAtEndIteraton[closeableAt] = new ArrayList<>();
|
||||
}
|
||||
|
@ -2373,18 +2374,18 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
out = features[vIdx];
|
||||
} else {
|
||||
|
||||
if(fwdPassType == FwdPassType.STANDARD){
|
||||
if (fwdPassType == FwdPassType.STANDARD) {
|
||||
//Standard feed-forward case
|
||||
out = current.doForward(train, workspaceMgr);
|
||||
} else if(fwdPassType == FwdPassType.RNN_TIMESTEP){
|
||||
} else if (fwdPassType == FwdPassType.RNN_TIMESTEP) {
|
||||
if (current.hasLayer()) {
|
||||
//Layer
|
||||
INDArray input = current.getInputs()[0];
|
||||
Layer l = current.getLayer();
|
||||
if (l instanceof RecurrentLayer) {
|
||||
out = ((RecurrentLayer) l).rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr);
|
||||
} else if(l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer && ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying() instanceof RecurrentLayer){
|
||||
RecurrentLayer rl = ((RecurrentLayer) ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying());
|
||||
} else if (l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer && ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) l).getUnderlying() instanceof RecurrentLayer) {
|
||||
RecurrentLayer rl = ((RecurrentLayer) ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) l).getUnderlying());
|
||||
out = rl.rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr);
|
||||
} else if (l instanceof MultiLayerNetwork) {
|
||||
out = ((MultiLayerNetwork) l).rnnTimeStep(reshapeTimeStepInput(input));
|
||||
|
@ -2402,7 +2403,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
validateArrayWorkspaces(workspaceMgr, out, ArrayType.ACTIVATIONS, vName, false, "Feed forward (inference)");
|
||||
}
|
||||
|
||||
if(inputsTo != null) { //Output vertices may not input to any other vertices
|
||||
if (inputsTo != null) { //Output vertices may not input to any other vertices
|
||||
for (VertexIndices v : inputsTo) {
|
||||
//Note that we don't have to do anything special here: the activations are always detached in
|
||||
// this method
|
||||
|
@ -2412,13 +2413,13 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
}
|
||||
}
|
||||
|
||||
if(clearLayerInputs) {
|
||||
if (clearLayerInputs) {
|
||||
current.clear();
|
||||
}
|
||||
|
||||
if(isRequiredOutput){
|
||||
if (isRequiredOutput) {
|
||||
outputs[ArrayUtils.indexOf(layerIndexes, vIdx)] = out;
|
||||
if(origWSAct != null){
|
||||
if (origWSAct != null) {
|
||||
//Reset the configuration, as we may reuse this workspace manager...
|
||||
workspaceMgr.setWorkspace(ArrayType.ACTIVATIONS, origWSAct, origWSActConf);
|
||||
}
|
||||
|
@ -2428,14 +2429,16 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
//Close any activations workspaces that we no longer require
|
||||
//Note that activations workspaces can be closed only once the corresponding output activations have
|
||||
// been fully consumed
|
||||
if(closeAtEndIteraton[i] != null){
|
||||
for(MemoryWorkspace wsAct : closeAtEndIteraton[i]){
|
||||
if (closeAtEndIteraton[i] != null) {
|
||||
for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
|
||||
wsAct.close();
|
||||
LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
|
||||
freeWorkspaceManagers.add(canNowReuse);
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (Throwable t2){
|
||||
t = t2;
|
||||
} finally {
|
||||
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown
|
||||
//Though if stopIndex < numLayers, some might still be open
|
||||
|
@ -2444,7 +2447,15 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
//Edge case here: seems that scoping out can increase the tagScope of the current WS
|
||||
//and if we hit an exception during forward pass, we aren't guaranteed to call close a sufficient
|
||||
// number of times to actually close it, in all cases
|
||||
try{
|
||||
ws.close();
|
||||
} catch (Throwable t2){
|
||||
if(t != null){
|
||||
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||
log.error("Original exception:", t);
|
||||
throw t2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
|
||||
|
@ -2581,28 +2592,29 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
|
||||
boolean traceLog = log.isTraceEnabled();
|
||||
|
||||
try{
|
||||
for(int i=topologicalOrder.length-1; i>= 0; i--){
|
||||
Throwable t = null;
|
||||
try {
|
||||
for (int i = topologicalOrder.length - 1; i >= 0; i--) {
|
||||
boolean hitFrozen = false;
|
||||
GraphVertex current = vertices[topologicalOrder[i]];
|
||||
int vIdx = current.getVertexIndex();
|
||||
String vertexName = current.getVertexName();
|
||||
|
||||
if(traceLog){
|
||||
if (traceLog) {
|
||||
log.trace("About backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName());
|
||||
}
|
||||
|
||||
//FIXME: make the frozen vertex feature extraction more flexible
|
||||
if (current.hasLayer() && current.getLayer() instanceof FrozenLayer || current instanceof FrozenVertex){
|
||||
if (current.hasLayer() && current.getLayer() instanceof FrozenLayer || current instanceof FrozenVertex) {
|
||||
hitFrozen = true;
|
||||
}
|
||||
|
||||
if (current.isInputVertex() || hitFrozen){
|
||||
if (current.isInputVertex() || hitFrozen) {
|
||||
//Close any activation gradient workspaces that we no longer require
|
||||
//Note that activation gradient workspaces can be closed only once the corresponding activations
|
||||
// gradients have been fully consumed
|
||||
if(closeAtEndIteraton[i] != null){
|
||||
for(MemoryWorkspace wsAct : closeAtEndIteraton[i]){
|
||||
if (closeAtEndIteraton[i] != null) {
|
||||
for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
|
||||
wsAct.close();
|
||||
LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
|
||||
freeWorkspaceManagers.add(canNowReuse);
|
||||
|
@ -2680,7 +2692,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
wsActivationGrads.setPreviousWorkspace(initialWorkspace);
|
||||
|
||||
int closeableAt = vertexActGradsFullyConsumedByStep[vIdx];
|
||||
if(closeableAt >= 0) {
|
||||
if (closeableAt >= 0) {
|
||||
if (closeAtEndIteraton[closeableAt] == null) {
|
||||
closeAtEndIteraton[closeableAt] = new ArrayList<>();
|
||||
}
|
||||
|
@ -2689,14 +2701,14 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
|
||||
Pair<Gradient, INDArray[]> pair;
|
||||
INDArray[] epsilons;
|
||||
try(MemoryWorkspace wsWorkingMem = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)){
|
||||
try (MemoryWorkspace wsWorkingMem = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)) {
|
||||
pair = current.doBackward(truncatedBPTT, workspaceMgr);
|
||||
epsilons = pair.getSecond();
|
||||
|
||||
//Validate workspace location for the activation gradients:
|
||||
//validateArrayWorkspaces(LayerWorkspaceMgr mgr, INDArray array, ArrayType arrayType, String vertexName, boolean isInputVertex, String op){
|
||||
for (INDArray epsilon : epsilons) {
|
||||
if(epsilon != null) {
|
||||
if (epsilon != null) {
|
||||
//May be null for EmbeddingLayer, etc
|
||||
validateArrayWorkspaces(workspaceMgr, epsilon, ArrayType.ACTIVATION_GRAD, vertexName, false, "Backprop");
|
||||
}
|
||||
|
@ -2732,15 +2744,15 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
tempList.addFirst(new Triple<>(newName, entry.getValue(),
|
||||
g.flatteningOrderForVariable(origName)));
|
||||
}
|
||||
for (Triple<String, INDArray, Character> t : tempList)
|
||||
gradients.addFirst(t);
|
||||
for (Triple<String, INDArray, Character> triple : tempList)
|
||||
gradients.addFirst(triple);
|
||||
}
|
||||
|
||||
//Close any activation gradient workspaces that we no longer require
|
||||
//Note that activation gradient workspaces can be closed only once the corresponding activations
|
||||
// gradients have been fully consumed
|
||||
if(closeAtEndIteraton[i] != null){
|
||||
for(MemoryWorkspace wsAct : closeAtEndIteraton[i]){
|
||||
if (closeAtEndIteraton[i] != null) {
|
||||
for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
|
||||
wsAct.close();
|
||||
LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
|
||||
freeWorkspaceManagers.add(canNowReuse);
|
||||
|
@ -2748,23 +2760,32 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
|
|||
closeAtEndIteraton[i] = null;
|
||||
}
|
||||
|
||||
if(traceLog){
|
||||
if (traceLog) {
|
||||
log.trace("Completed backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName());
|
||||
}
|
||||
}
|
||||
|
||||
} catch (Throwable t2){
|
||||
t = t2;
|
||||
} finally {
|
||||
//Close all open workspaces... usually this list will be empty, but not if an exception is thrown
|
||||
for(MemoryWorkspace ws : openActivationsWorkspaces.keySet()){
|
||||
try{
|
||||
ws.close();
|
||||
} catch (Throwable t2){
|
||||
if(t != null){
|
||||
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||
log.error("Original exception:", t);
|
||||
throw t2;
|
||||
}
|
||||
}
|
||||
}
|
||||
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
|
||||
}
|
||||
|
||||
//Now, add the gradients in the order we need them in for flattening (same as params order)
|
||||
Gradient gradient = new DefaultGradient(flattenedGradients);
|
||||
for (Triple<String, INDArray, Character> t : gradients) {
|
||||
gradient.setGradientFor(t.getFirst(), t.getSecond(), t.getThird());
|
||||
for (Triple<String, INDArray, Character> tr : gradients) {
|
||||
gradient.setGradientFor(tr.getFirst(), tr.getSecond(), tr.getThird());
|
||||
}
|
||||
|
||||
this.gradient = gradient;
|
||||
|
|
|
@ -1242,17 +1242,18 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
|
||||
boolean traceLog = log.isTraceEnabled();
|
||||
|
||||
Throwable t = null;
|
||||
try {
|
||||
for (int i = 0; i <= layerIndex; i++) {
|
||||
LayerWorkspaceMgr mgr = (i % 2 == 0 ? mgrEven : mgrOdd);
|
||||
|
||||
if(traceLog){
|
||||
if (traceLog) {
|
||||
log.trace("About to forward pass: {} - {}", i, layers[i].getClass().getSimpleName());
|
||||
}
|
||||
|
||||
//Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet)
|
||||
//Hence: put inputs in working memory
|
||||
if(i == 0 && wsm != WorkspaceMode.NONE){
|
||||
if (i == 0 && wsm != WorkspaceMode.NONE) {
|
||||
mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG);
|
||||
}
|
||||
|
||||
|
@ -1268,7 +1269,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
temp.setPreviousWorkspace(initialWorkspace);
|
||||
|
||||
|
||||
if(i == 0 && input.isAttached()){
|
||||
if (i == 0 && input.isAttached()) {
|
||||
//Don't leverage out of async DataSetIterator workspaces
|
||||
mgr.setNoLeverageOverride(input.data().getParentWorkspace().getId());
|
||||
}
|
||||
|
@ -1279,8 +1280,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, true, "Output of layer (inference)");
|
||||
}
|
||||
|
||||
if ( i == layerIndex ) {
|
||||
if(outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)){
|
||||
if (i == layerIndex) {
|
||||
if (outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)) {
|
||||
//Place activations in user-specified workspace
|
||||
mgr.setWorkspace(ArrayType.ACTIVATIONS, outputWorkspace.getId(), outputWorkspace.getWorkspaceConfiguration());
|
||||
} else {
|
||||
|
@ -1289,15 +1290,15 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
}
|
||||
}
|
||||
|
||||
if(fwdPassType == FwdPassType.STANDARD){
|
||||
if (fwdPassType == FwdPassType.STANDARD) {
|
||||
//Standard feed-forward case
|
||||
input = layers[i].activate(input, train, mgr);
|
||||
} else if(fwdPassType == FwdPassType.RNN_TIMESTEP){
|
||||
} else if (fwdPassType == FwdPassType.RNN_TIMESTEP) {
|
||||
//rnnTimeStep case
|
||||
if (layers[i] instanceof RecurrentLayer) {
|
||||
input = ((RecurrentLayer) layers[i]).rnnTimeStep(reshapeTimeStepInput(input), mgr);
|
||||
} else if(layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer)layers[i]).getUnderlying() instanceof RecurrentLayer){
|
||||
RecurrentLayer rl = ((RecurrentLayer) ((BaseWrapperLayer)layers[i]).getUnderlying());
|
||||
} else if (layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer) layers[i]).getUnderlying() instanceof RecurrentLayer) {
|
||||
RecurrentLayer rl = ((RecurrentLayer) ((BaseWrapperLayer) layers[i]).getUnderlying());
|
||||
input = rl.rnnTimeStep(reshapeTimeStepInput(input), mgr);
|
||||
} else if (layers[i] instanceof MultiLayerNetwork) {
|
||||
input = ((MultiLayerNetwork) layers[i]).rnnTimeStep(reshapeTimeStepInput(input));
|
||||
|
@ -1311,34 +1312,51 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
//Validation: Exception if invalid (bad layer implementation)
|
||||
validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, false, "Output of layer (inference)");
|
||||
|
||||
if(wsActCloseNext != null){
|
||||
if (wsActCloseNext != null) {
|
||||
wsActCloseNext.close();
|
||||
}
|
||||
wsActCloseNext = temp;
|
||||
temp = null;
|
||||
}
|
||||
|
||||
if(traceLog){
|
||||
if (traceLog) {
|
||||
log.trace("Completed forward pass: {} - {}", i, layers[i].getClass().getSimpleName());
|
||||
}
|
||||
|
||||
//Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet)
|
||||
//Hence: put inputs in working memory -> set back to default for next use of workspace mgr
|
||||
if(i == 0 && wsm != WorkspaceMode.NONE){
|
||||
if (i == 0 && wsm != WorkspaceMode.NONE) {
|
||||
mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG); //Inputs should always be in the previous WS
|
||||
}
|
||||
}
|
||||
|
||||
} catch (Throwable t2){
|
||||
t = t2;
|
||||
} finally {
|
||||
if(wsActCloseNext != null){
|
||||
try {
|
||||
wsActCloseNext.close();
|
||||
} catch (Throwable t2){
|
||||
if(t != null){
|
||||
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||
log.error("Original exception:", t);
|
||||
throw t2;
|
||||
}
|
||||
}
|
||||
}
|
||||
if(temp != null){
|
||||
//Should only be non-null on exception
|
||||
while(temp.isScopeActive()){
|
||||
//For safety, should never occur in theory: a single close() call may not be sufficient, if
|
||||
// workspace scope was borrowed and not properly closed when exception occurred
|
||||
try{
|
||||
temp.close();
|
||||
} catch (Throwable t2){
|
||||
if(t != null){
|
||||
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||
log.error("Original exception:", t);
|
||||
throw t2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1871,13 +1889,14 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
|
||||
boolean traceLog = log.isTraceEnabled();
|
||||
|
||||
Throwable t = null;
|
||||
try {
|
||||
for (int i = layers.length - 1; i >= 0; i--) {
|
||||
if (layers[i] instanceof FrozenLayer) {
|
||||
break;
|
||||
}
|
||||
|
||||
if(traceLog){
|
||||
if (traceLog) {
|
||||
log.trace("About to backprop: {} - {}", i, layers[i].getClass().getSimpleName());
|
||||
}
|
||||
|
||||
|
@ -1897,7 +1916,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
|
||||
//Open activation gradients WS *then* BP working memory, so BP working memory is opened last for use in layers
|
||||
wsActGradTemp = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATION_GRAD);
|
||||
try(MemoryWorkspace wsBPWorking = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)){
|
||||
try (MemoryWorkspace wsBPWorking = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)) {
|
||||
|
||||
//Note that because we're opening activation workspaces not in a simple nested order, we'll manually
|
||||
// override the previous workspace setting. Otherwise, when we close these workspaces, the "current"
|
||||
|
@ -1907,7 +1926,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
|
||||
INDArray eps = (i == layers.length - 1 ? epsilon : currPair.getRight()); //eps is null for OutputLayer
|
||||
|
||||
if(!tbptt){
|
||||
if (!tbptt) {
|
||||
//Standard case
|
||||
currPair = layers[i].backpropGradient(eps, workspaceMgr);
|
||||
} else {
|
||||
|
@ -1920,7 +1939,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
}
|
||||
}
|
||||
|
||||
if(currPair.getSecond() != null) {
|
||||
if (currPair.getSecond() != null) {
|
||||
//Edge case: may be null for Embedding layer, for example
|
||||
validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i,
|
||||
false, "Backprop");
|
||||
|
@ -1936,38 +1955,56 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, Neura
|
|||
currPair = new Pair<>(currPair.getFirst(),
|
||||
this.layerWiseConfigurations.getInputPreProcess(i)
|
||||
.backprop(currPair.getSecond(), getInputMiniBatchSize(), workspaceMgr));
|
||||
if (i > 0 && currPair.getSecond() != null){
|
||||
if (i > 0 && currPair.getSecond() != null) {
|
||||
validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i,
|
||||
true, "Backprop");
|
||||
}
|
||||
}
|
||||
|
||||
if(i == 0 ){
|
||||
if(returnInputActGrad && currPair.getSecond() != null){
|
||||
if (i == 0) {
|
||||
if (returnInputActGrad && currPair.getSecond() != null) {
|
||||
currPair.setSecond(currPair.getSecond().detach());
|
||||
} else {
|
||||
currPair.setSecond(null);
|
||||
}
|
||||
}
|
||||
|
||||
if(wsActGradCloseNext != null){
|
||||
if (wsActGradCloseNext != null) {
|
||||
wsActGradCloseNext.close();
|
||||
}
|
||||
wsActGradCloseNext = wsActGradTemp;
|
||||
wsActGradTemp = null;
|
||||
}
|
||||
|
||||
if(traceLog){
|
||||
if (traceLog) {
|
||||
log.trace("Completed backprop: {} - {}", i, layers[i].getClass().getSimpleName());
|
||||
}
|
||||
}
|
||||
} catch (Throwable thr ){
|
||||
t = thr;
|
||||
} finally {
|
||||
if(wsActGradCloseNext != null){
|
||||
try {
|
||||
wsActGradCloseNext.close();
|
||||
} catch (Throwable t2){
|
||||
if(t != null){
|
||||
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||
log.error("Original exception:", t);
|
||||
throw t2;
|
||||
}
|
||||
if(wsActGradTemp != null){
|
||||
}
|
||||
}
|
||||
if(wsActGradTemp != null) {
|
||||
//Should only be non-null on exception
|
||||
try {
|
||||
wsActGradTemp.close();
|
||||
} catch (Throwable t2) {
|
||||
if (t != null) {
|
||||
log.error("Encountered second exception while trying to close workspace after initial exception");
|
||||
log.error("Original exception:", t);
|
||||
throw t2;
|
||||
}
|
||||
}
|
||||
}
|
||||
Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue