MLN/CG: Don't swallow exceptions if a second exception occurs during workspace closing (#161)
Signed-off-by: AlexDBlack <blacka101@gmail.com>
Branch: master
parent f8364997c0
commit b85238a6df
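The fix applies one pattern throughout MultiLayerNetwork and ComputationGraph: capture the primary exception from the forward or backward pass, attempt workspace cleanup in a finally block, and if cleanup itself throws while a primary exception is pending, log the primary exception before propagating the cleanup failure, so the close-time error cannot silently mask the original one. Below is a minimal standalone sketch of that pattern. The Workspace interface and doWork runnable are hypothetical stand-ins for DL4J's MemoryWorkspace and pass logic, and the final re-throw of the primary exception is an assumption about surrounding code that the hunks below do not show:

    import java.util.List;

    public class PreserveOriginalExceptionSketch {

        // Hypothetical stand-in for DL4J's MemoryWorkspace
        interface Workspace {
            void close(); // no checked exceptions declared, so a precise rethrow of Throwable compiles
        }

        static void runPass(Runnable doWork, List<Workspace> openWorkspaces) {
            Throwable t = null;
            try {
                doWork.run();                // the actual forward/backward pass
            } catch (Throwable t2) {
                t = t2;                      // remember the primary failure; don't rethrow yet
            } finally {
                for (Workspace ws : openWorkspaces) {
                    try {
                        ws.close();          // cleanup may itself fail
                    } catch (Throwable t2) {
                        if (t != null) {
                            // Second failure during cleanup: surface the primary exception
                            // before it is lost, then propagate the cleanup failure.
                            System.err.println("Second exception while closing workspace after initial exception");
                            t.printStackTrace();
                            throw t2;
                        }
                        // Matching the diff: a close failure with no pending primary
                        // exception is swallowed here rather than rethrown.
                    }
                }
            }
            if (t != null) {                 // assumed: the primary failure is re-thrown after cleanup
                if (t instanceof RuntimeException) throw (RuntimeException) t;
                if (t instanceof Error) throw (Error) t;
                throw new RuntimeException(t);
            }
        }
    }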
@@ -2278,6 +2278,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
         LayerWorkspaceMgr allNone = noWS ? LayerWorkspaceMgr.noWorkspaces(helperWorkspaces) : null;
         List<MemoryWorkspace>[] closeAtEndIteraton = (List<MemoryWorkspace>[])new List[topologicalOrder.length];
         MemoryWorkspace initialWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
+        Throwable t = null;
         try {
             for (int i = 0; i <= stopIndex; i++) {
                 GraphVertex current = vertices[topologicalOrder[i]];
@@ -2302,14 +2303,14 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
                             .with(ArrayType.RNN_FF_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM, WS_RNN_LOOP_WORKING_MEM_CONFIG)
                             .build();

-                    if(detachedInputs){
+                    if (detachedInputs) {
                         //Sometimes (like: external errors use cases) we don't want the activations/inputs to be
                         // in a workspace
                         workspaceMgr.setScopedOutFor(ArrayType.INPUT);
                         workspaceMgr.setScopedOutFor(ArrayType.ACTIVATIONS);
                     } else {
                         //Don't leverage out of async MultiDataSetIterator workspaces
-                        if(features[0].isAttached()){
+                        if (features[0].isAttached()) {
                             workspaceMgr.setNoLeverageOverride(features[0].data().getParentWorkspace().getId());
                         }
                     }
@@ -2326,7 +2327,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
                 if (ArrayUtils.contains(layerIndexes, vIdx)) {
                     isRequiredOutput = true;

-                    if(outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)){
+                    if (outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)) {
                         //Place activations in user-specified workspace
                         origWSAct = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
                         origWSActConf = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
@@ -2345,7 +2346,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
                 //Open the relevant workspace for the activations.
                 //Note that this will be closed only once the current vertex's activations have been consumed
                 MemoryWorkspace wsActivations = null;
-                if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || !isRequiredOutput ){ //Open WS if (a) no external/output WS (if present, it's already open), or (b) not being placed in external/output WS
+                if (outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || !isRequiredOutput) { //Open WS if (a) no external/output WS (if present, it's already open), or (b) not being placed in external/output WS
                     wsActivations = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATIONS);
                     openActivationsWorkspaces.put(wsActivations, workspaceMgr);
                 }
@@ -2353,11 +2354,11 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
                 //Note that because we're opening activation workspaces not in any defined order (i.e., workspace
                 // use isn't simply nested), we'll manually override the previous workspace setting. Otherwise, when we
                 // close these workspaces, the "current" workspace may be set to the incorrect one
-                if(wsActivations != null )
+                if (wsActivations != null)
                     wsActivations.setPreviousWorkspace(initialWorkspace);

                 int closeableAt = vertexOutputsFullyConsumedByStep[vIdx];
-                if(outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || (wsActivations != null && !outputWorkspace.getId().equals(wsActivations.getId()))) {
+                if (outputWorkspace == null || outputWorkspace instanceof DummyWorkspace || (wsActivations != null && !outputWorkspace.getId().equals(wsActivations.getId()))) {
                     if (closeAtEndIteraton[closeableAt] == null) {
                         closeAtEndIteraton[closeableAt] = new ArrayList<>();
                     }
@@ -2373,18 +2374,18 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
                     out = features[vIdx];
                 } else {

-                    if(fwdPassType == FwdPassType.STANDARD){
+                    if (fwdPassType == FwdPassType.STANDARD) {
                         //Standard feed-forward case
                         out = current.doForward(train, workspaceMgr);
-                    } else if(fwdPassType == FwdPassType.RNN_TIMESTEP){
+                    } else if (fwdPassType == FwdPassType.RNN_TIMESTEP) {
                         if (current.hasLayer()) {
                             //Layer
                             INDArray input = current.getInputs()[0];
                             Layer l = current.getLayer();
                             if (l instanceof RecurrentLayer) {
                                 out = ((RecurrentLayer) l).rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr);
-                            } else if(l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer && ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying() instanceof RecurrentLayer){
-                                RecurrentLayer rl = ((RecurrentLayer) ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer)l).getUnderlying());
+                            } else if (l instanceof org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer && ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) l).getUnderlying() instanceof RecurrentLayer) {
+                                RecurrentLayer rl = ((RecurrentLayer) ((org.deeplearning4j.nn.layers.wrapper.BaseWrapperLayer) l).getUnderlying());
                                 out = rl.rnnTimeStep(reshapeTimeStepInput(input), workspaceMgr);
                             } else if (l instanceof MultiLayerNetwork) {
                                 out = ((MultiLayerNetwork) l).rnnTimeStep(reshapeTimeStepInput(input));
@@ -2402,7 +2403,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
                     validateArrayWorkspaces(workspaceMgr, out, ArrayType.ACTIVATIONS, vName, false, "Feed forward (inference)");
                 }

-                if(inputsTo != null) { //Output vertices may not input to any other vertices
+                if (inputsTo != null) { //Output vertices may not input to any other vertices
                     for (VertexIndices v : inputsTo) {
                         //Note that we don't have to do anything special here: the activations are always detached in
                         // this method
@@ -2412,13 +2413,13 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
                     }
                 }

-                if(clearLayerInputs) {
+                if (clearLayerInputs) {
                     current.clear();
                 }

-                if(isRequiredOutput){
+                if (isRequiredOutput) {
                     outputs[ArrayUtils.indexOf(layerIndexes, vIdx)] = out;
-                    if(origWSAct != null){
+                    if (origWSAct != null) {
                         //Reset the configuration, as we may reuse this workspace manager...
                         workspaceMgr.setWorkspace(ArrayType.ACTIVATIONS, origWSAct, origWSActConf);
                     }
@@ -2428,14 +2429,16 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
                 //Close any activations workspaces that we no longer require
                 //Note that activations workspaces can be closed only once the corresponding output activations have
                 // been fully consumed
-                if(closeAtEndIteraton[i] != null){
-                    for(MemoryWorkspace wsAct : closeAtEndIteraton[i]){
+                if (closeAtEndIteraton[i] != null) {
+                    for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
                         wsAct.close();
                         LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
                         freeWorkspaceManagers.add(canNowReuse);
                     }
                 }
             }
+        } catch (Throwable t2){
+            t = t2;
         } finally {
             //Close all open workspaces... usually this list will be empty, but not if an exception is thrown
             //Though if stopIndex < numLayers, some might still be open
@@ -2444,7 +2447,15 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
                 //Edge case here: seems that scoping out can increase the tagScope of the current WS
                 //and if we hit an exception during forward pass, we aren't guaranteed to call close a sufficient
                 // number of times to actually close it, in all cases
+                try{
                     ws.close();
+                } catch (Throwable t2){
+                    if(t != null){
+                        log.error("Encountered second exception while trying to close workspace after initial exception");
+                        log.error("Original exception:", t);
+                        throw t2;
+                    }
+                }
             }
         }
         Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
@@ -2581,28 +2592,29 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {

         boolean traceLog = log.isTraceEnabled();

-        try{
-            for(int i=topologicalOrder.length-1; i>= 0; i--){
+        Throwable t = null;
+        try {
+            for (int i = topologicalOrder.length - 1; i >= 0; i--) {
                 boolean hitFrozen = false;
                 GraphVertex current = vertices[topologicalOrder[i]];
                 int vIdx = current.getVertexIndex();
                 String vertexName = current.getVertexName();

-                if(traceLog){
+                if (traceLog) {
                     log.trace("About backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName());
                 }

                 //FIXME: make the frozen vertex feature extraction more flexible
-                if (current.hasLayer() && current.getLayer() instanceof FrozenLayer || current instanceof FrozenVertex){
+                if (current.hasLayer() && current.getLayer() instanceof FrozenLayer || current instanceof FrozenVertex) {
                     hitFrozen = true;
                 }

-                if (current.isInputVertex() || hitFrozen){
+                if (current.isInputVertex() || hitFrozen) {
                     //Close any activation gradient workspaces that we no longer require
                     //Note that activation gradient workspaces can be closed only once the corresponding activations
                     // gradients have been fully consumed
-                    if(closeAtEndIteraton[i] != null){
-                        for(MemoryWorkspace wsAct : closeAtEndIteraton[i]){
+                    if (closeAtEndIteraton[i] != null) {
+                        for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
                             wsAct.close();
                             LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
                             freeWorkspaceManagers.add(canNowReuse);
@@ -2680,7 +2692,7 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
                     wsActivationGrads.setPreviousWorkspace(initialWorkspace);

                     int closeableAt = vertexActGradsFullyConsumedByStep[vIdx];
-                    if(closeableAt >= 0) {
+                    if (closeableAt >= 0) {
                         if (closeAtEndIteraton[closeableAt] == null) {
                             closeAtEndIteraton[closeableAt] = new ArrayList<>();
                         }
@@ -2689,14 +2701,14 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {

                 Pair<Gradient, INDArray[]> pair;
                 INDArray[] epsilons;
-                try(MemoryWorkspace wsWorkingMem = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)){
+                try (MemoryWorkspace wsWorkingMem = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)) {
                     pair = current.doBackward(truncatedBPTT, workspaceMgr);
                     epsilons = pair.getSecond();

                     //Validate workspace location for the activation gradients:
                     //validateArrayWorkspaces(LayerWorkspaceMgr mgr, INDArray array, ArrayType arrayType, String vertexName, boolean isInputVertex, String op){
                     for (INDArray epsilon : epsilons) {
-                        if(epsilon != null) {
+                        if (epsilon != null) {
                             //May be null for EmbeddingLayer, etc
                             validateArrayWorkspaces(workspaceMgr, epsilon, ArrayType.ACTIVATION_GRAD, vertexName, false, "Backprop");
                         }
@@ -2732,15 +2744,15 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
                         tempList.addFirst(new Triple<>(newName, entry.getValue(),
                                 g.flatteningOrderForVariable(origName)));
                     }
-                    for (Triple<String, INDArray, Character> t : tempList)
-                        gradients.addFirst(t);
+                    for (Triple<String, INDArray, Character> triple : tempList)
+                        gradients.addFirst(triple);
                 }

                 //Close any activation gradient workspaces that we no longer require
                 //Note that activation gradient workspaces can be closed only once the corresponding activations
                 // gradients have been fully consumed
-                if(closeAtEndIteraton[i] != null){
-                    for(MemoryWorkspace wsAct : closeAtEndIteraton[i]){
+                if (closeAtEndIteraton[i] != null) {
+                    for (MemoryWorkspace wsAct : closeAtEndIteraton[i]) {
                         wsAct.close();
                         LayerWorkspaceMgr canNowReuse = openActivationsWorkspaces.remove(wsAct);
                         freeWorkspaceManagers.add(canNowReuse);
@@ -2748,23 +2760,32 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
                     closeAtEndIteraton[i] = null;
                 }

-                if(traceLog){
+                if (traceLog) {
                     log.trace("Completed backprop: {} (\"{}\") - {}", i, vertexName, current.getClass().getSimpleName());
                 }
             }
+        } catch (Throwable t2){
+            t = t2;
         } finally {
             //Close all open workspaces... usually this list will be empty, but not if an exception is thrown
             for(MemoryWorkspace ws : openActivationsWorkspaces.keySet()){
+                try{
                     ws.close();
+                } catch (Throwable t2){
+                    if(t != null){
+                        log.error("Encountered second exception while trying to close workspace after initial exception");
+                        log.error("Original exception:", t);
+                        throw t2;
+                    }
+                }
             }
             Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
         }

         //Now, add the gradients in the order we need them in for flattening (same as params order)
         Gradient gradient = new DefaultGradient(flattenedGradients);
-        for (Triple<String, INDArray, Character> t : gradients) {
-            gradient.setGradientFor(t.getFirst(), t.getSecond(), t.getThird());
+        for (Triple<String, INDArray, Character> tr : gradients) {
+            gradient.setGradientFor(tr.getFirst(), tr.getSecond(), tr.getThird());
         }

         this.gradient = gradient;
@@ -1242,17 +1242,18 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {

         boolean traceLog = log.isTraceEnabled();

+        Throwable t = null;
         try {
             for (int i = 0; i <= layerIndex; i++) {
                 LayerWorkspaceMgr mgr = (i % 2 == 0 ? mgrEven : mgrOdd);

-                if(traceLog){
+                if (traceLog) {
                     log.trace("About to forward pass: {} - {}", i, layers[i].getClass().getSimpleName());
                 }

                 //Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet)
                 //Hence: put inputs in working memory
-                if(i == 0 && wsm != WorkspaceMode.NONE){
+                if (i == 0 && wsm != WorkspaceMode.NONE) {
                     mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_WORKING_MEM, WS_LAYER_WORKING_MEM_CONFIG);
                 }

@@ -1268,7 +1269,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {
                 temp.setPreviousWorkspace(initialWorkspace);


-                if(i == 0 && input.isAttached()){
+                if (i == 0 && input.isAttached()) {
                     //Don't leverage out of async DataSetIterator workspaces
                     mgr.setNoLeverageOverride(input.data().getParentWorkspace().getId());
                 }
@@ -1279,8 +1280,8 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {
                     validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, true, "Output of layer (inference)");
                 }

-                if ( i == layerIndex ) {
-                    if(outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)){
+                if (i == layerIndex) {
+                    if (outputWorkspace != null && !(outputWorkspace instanceof DummyWorkspace)) {
                         //Place activations in user-specified workspace
                         mgr.setWorkspace(ArrayType.ACTIVATIONS, outputWorkspace.getId(), outputWorkspace.getWorkspaceConfiguration());
                     } else {
@@ -1289,15 +1290,15 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {
                     }
                 }

-                if(fwdPassType == FwdPassType.STANDARD){
+                if (fwdPassType == FwdPassType.STANDARD) {
                     //Standard feed-forward case
                     input = layers[i].activate(input, train, mgr);
-                } else if(fwdPassType == FwdPassType.RNN_TIMESTEP){
+                } else if (fwdPassType == FwdPassType.RNN_TIMESTEP) {
                     //rnnTimeStep case
                     if (layers[i] instanceof RecurrentLayer) {
                         input = ((RecurrentLayer) layers[i]).rnnTimeStep(reshapeTimeStepInput(input), mgr);
-                    } else if(layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer)layers[i]).getUnderlying() instanceof RecurrentLayer){
-                        RecurrentLayer rl = ((RecurrentLayer) ((BaseWrapperLayer)layers[i]).getUnderlying());
+                    } else if (layers[i] instanceof BaseWrapperLayer && ((BaseWrapperLayer) layers[i]).getUnderlying() instanceof RecurrentLayer) {
+                        RecurrentLayer rl = ((RecurrentLayer) ((BaseWrapperLayer) layers[i]).getUnderlying());
                         input = rl.rnnTimeStep(reshapeTimeStepInput(input), mgr);
                     } else if (layers[i] instanceof MultiLayerNetwork) {
                         input = ((MultiLayerNetwork) layers[i]).rnnTimeStep(reshapeTimeStepInput(input));
@@ -1311,34 +1312,51 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {
                     //Validation: Exception if invalid (bad layer implementation)
                     validateArrayWorkspaces(mgr, input, ArrayType.ACTIVATIONS, i, false, "Output of layer (inference)");

-                    if(wsActCloseNext != null){
+                    if (wsActCloseNext != null) {
                         wsActCloseNext.close();
                     }
                     wsActCloseNext = temp;
                     temp = null;
                 }

-                if(traceLog){
+                if (traceLog) {
                     log.trace("Completed forward pass: {} - {}", i, layers[i].getClass().getSimpleName());
                 }

                 //Edge case: for first layer with dropout, inputs can't be in previous workspace (as it hasn't been opened yet)
                 //Hence: put inputs in working memory -> set back to default for next use of workspace mgr
-                if(i == 0 && wsm != WorkspaceMode.NONE){
+                if (i == 0 && wsm != WorkspaceMode.NONE) {
                     mgr.setWorkspace(ArrayType.INPUT, WS_LAYER_ACT_2, WS_LAYER_ACT_X_CONFIG); //Inputs should always be in the previous WS
                 }
             }
+        } catch (Throwable t2){
+            t = t2;
         } finally {
             if(wsActCloseNext != null){
+                try {
                     wsActCloseNext.close();
+                } catch (Throwable t2){
+                    if(t != null){
+                        log.error("Encountered second exception while trying to close workspace after initial exception");
+                        log.error("Original exception:", t);
+                        throw t2;
+                    }
+                }
             }
             if(temp != null){
                 //Should only be non-null on exception
                 while(temp.isScopeActive()){
                     //For safety, should never occur in theory: a single close() call may not be sufficient, if
                     // workspace scope was borrowed and not properly closed when exception occurred
+                    try{
                         temp.close();
+                    } catch (Throwable t2){
+                        if(t != null){
+                            log.error("Encountered second exception while trying to close workspace after initial exception");
+                            log.error("Original exception:", t);
+                            throw t2;
+                        }
+                    }
                 }
             }

@@ -1871,13 +1889,14 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {

         boolean traceLog = log.isTraceEnabled();

+        Throwable t = null;
         try {
             for (int i = layers.length - 1; i >= 0; i--) {
                 if (layers[i] instanceof FrozenLayer) {
                     break;
                 }

-                if(traceLog){
+                if (traceLog) {
                     log.trace("About to backprop: {} - {}", i, layers[i].getClass().getSimpleName());
                 }

@@ -1897,7 +1916,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {

                 //Open activation gradients WS *then* BP working memory, so BP working memory is opened last for use in layers
                 wsActGradTemp = workspaceMgr.notifyScopeEntered(ArrayType.ACTIVATION_GRAD);
-                try(MemoryWorkspace wsBPWorking = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)){
+                try (MemoryWorkspace wsBPWorking = workspaceMgr.notifyScopeEntered(ArrayType.BP_WORKING_MEM)) {

                     //Note that because we're opening activation workspaces not in a simple nested order, we'll manually
                     // override the previous workspace setting. Otherwise, when we close these workspaces, the "current"
@@ -1907,7 +1926,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {

                     INDArray eps = (i == layers.length - 1 ? epsilon : currPair.getRight()); //eps is null for OutputLayer

-                    if(!tbptt){
+                    if (!tbptt) {
                         //Standard case
                         currPair = layers[i].backpropGradient(eps, workspaceMgr);
                     } else {
@@ -1920,7 +1939,7 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {
                         }
                     }

-                    if(currPair.getSecond() != null) {
+                    if (currPair.getSecond() != null) {
                         //Edge case: may be null for Embedding layer, for example
                         validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i,
                                 false, "Backprop");
@@ -1936,38 +1955,56 @@ public class MultiLayerNetwork implements Serializable, Classifier, Layer, NeuralNetwork {
                         currPair = new Pair<>(currPair.getFirst(),
                                 this.layerWiseConfigurations.getInputPreProcess(i)
                                         .backprop(currPair.getSecond(), getInputMiniBatchSize(), workspaceMgr));
-                        if (i > 0 && currPair.getSecond() != null){
+                        if (i > 0 && currPair.getSecond() != null) {
                             validateArrayWorkspaces(workspaceMgr, currPair.getSecond(), ArrayType.ACTIVATION_GRAD, i,
                                     true, "Backprop");
                         }
                     }

-                    if(i == 0 ){
-                        if(returnInputActGrad && currPair.getSecond() != null){
+                    if (i == 0) {
+                        if (returnInputActGrad && currPair.getSecond() != null) {
                             currPair.setSecond(currPair.getSecond().detach());
                         } else {
                             currPair.setSecond(null);
                         }
                     }

-                    if(wsActGradCloseNext != null){
+                    if (wsActGradCloseNext != null) {
                         wsActGradCloseNext.close();
                     }
                     wsActGradCloseNext = wsActGradTemp;
                     wsActGradTemp = null;
                 }

-                if(traceLog){
+                if (traceLog) {
                     log.trace("Completed backprop: {} - {}", i, layers[i].getClass().getSimpleName());
                 }
             }
+        } catch (Throwable thr ){
+            t = thr;
         } finally {
             if(wsActGradCloseNext != null){
+                try {
                     wsActGradCloseNext.close();
+                } catch (Throwable t2){
+                    if(t != null){
+                        log.error("Encountered second exception while trying to close workspace after initial exception");
+                        log.error("Original exception:", t);
+                        throw t2;
+                    }
                 }
-            if(wsActGradTemp != null){
+            }
+            if(wsActGradTemp != null) {
                 //Should only be non-null on exception
+                try {
                     wsActGradTemp.close();
+                } catch (Throwable t2) {
+                    if (t != null) {
+                        log.error("Encountered second exception while trying to close workspace after initial exception");
+                        log.error("Original exception:", t);
+                        throw t2;
+                    }
+                }
             }
             Nd4j.getMemoryManager().setCurrentWorkspace(initialWorkspace);
         }

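A note on the design: Java's try-with-resources addresses the same masking problem by attaching the close-time failure to the primary exception via Throwable.addSuppressed(), but it only works for lexically scoped resources. The workspaces here are opened in one loop iteration and closed in a later one (tracked via closeAtEndIteraton, wsActCloseNext, and wsActGradCloseNext), so they cannot be managed by try-with-resources, hence the manual capture, log, and rethrow in the hunks above. For contrast, a minimal illustration of the built-in suppression behavior, using a hypothetical FailingWorkspace:

    public class SuppressedDemo {
        static class FailingWorkspace implements AutoCloseable {
            @Override
            public void close() {
                throw new IllegalStateException("close failed");
            }
        }

        public static void main(String[] args) {
            try {
                try (FailingWorkspace ws = new FailingWorkspace()) {
                    throw new RuntimeException("primary failure");
                }
            } catch (RuntimeException e) {
                // The primary exception propagates; the close-time failure is retained.
                System.out.println("caught: " + e.getMessage());
                for (Throwable s : e.getSuppressed()) {
                    System.out.println("suppressed: " + s.getMessage()); // prints "close failed"
                }
            }
        }
    }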