diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/CloseNetworkTests.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/CloseNetworkTests.java
new file mode 100644
index 000000000..882ec8479
--- /dev/null
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/misc/CloseNetworkTests.java
@@ -0,0 +1,176 @@
+/* ******************************************************************************
+ * Copyright (c) 2020 Konduit K.K.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.deeplearning4j.nn.misc;
+
+import org.deeplearning4j.BaseDL4JTest;
+import org.deeplearning4j.TestUtils;
+import org.deeplearning4j.nn.api.Updater;
+import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
+import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.inputs.InputType;
+import org.deeplearning4j.nn.conf.layers.*;
+import org.deeplearning4j.nn.graph.ComputationGraph;
+import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
+import org.junit.Test;
+import org.nd4j.linalg.activations.Activation;
+import org.nd4j.linalg.api.buffer.DataType;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.learning.config.Adam;
+
+import static org.junit.Assert.assertTrue;
+
+public class CloseNetworkTests extends BaseDL4JTest {
+
+    /**
+     * Builds and initializes a small convolutional network (convolution, batch normalization,
+     * subsampling, dense, output) for 28x28 single-channel input, using the Adam updater.
+     *
+     * @return the initialized network
+     */
+    public static MultiLayerNetwork getTestNet() {
+        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                .updater(new Adam(1e-3))
+                .list()
+                .layer(new ConvolutionLayer.Builder().nOut(5).kernelSize(3, 3).activation(Activation.TANH).build())
+                .layer(new BatchNormalization.Builder().nOut(5).build())
+                .layer(new SubsamplingLayer.Builder().build())
+                .layer(new DenseLayer.Builder().nOut(10).activation(Activation.RELU).build())
+                .layer(new OutputLayer.Builder().nOut(10).build())
+                .setInputType(InputType.convolutional(28, 28, 1))
+                .build();
+
+        MultiLayerNetwork net = new MultiLayerNetwork(conf);
+        net.init();
+
+        return net;
+    }
+
+    /**
+     * Asserts that an exception raised by using a closed network mentions "released". A null
+     * message is reported as an assertion failure instead of surfacing as a NullPointerException.
+     *
+     * @param e exception thrown by the closed network
+     */
+    private static void assertReleasedMessage(IllegalStateException e) {
+        String msg = e.getMessage();
+        assertTrue(String.valueOf(msg), msg != null && msg.contains("released"));
+    }
+
+    /**
+     * Closes a MultiLayerNetwork after each combination of fitting / not fitting and running
+     * inference / not running inference, then checks that the parameter array (and, when the
+     * network was fitted, the gradient and updater state arrays) report as closed.
+     */
+    @Test
+    public void testCloseMLN() {
+        for (boolean train : new boolean[]{false, true}) {
+            for (boolean test : new boolean[]{false, true}) {
+                MultiLayerNetwork net = getTestNet();
+
+                INDArray f = Nd4j.rand(DataType.FLOAT, 16, 1, 28, 28);
+                INDArray l = TestUtils.randomOneHot(16, 10);
+
+                if (train) {
+                    for (int i = 0; i < 3; i++) {
+                        net.fit(f, l);
+                    }
+                }
+
+                if (test) {
+                    for (int i = 0; i < 3; i++) {
+                        net.output(f);
+                    }
+                }
+
+                net.close();
+
+                assertTrue(net.params().wasClosed());
+                if (train) {
+                    assertTrue(net.getGradientsViewArray().wasClosed());
+                    Updater u = net.getUpdater(false);
+                    assertTrue(u.getStateViewArray().wasClosed());
+                }
+
+                //Make sure we don't get crashes etc when trying to use after closing: each call may
+                //either complete or throw an IllegalStateException whose message mentions "released"
+                try {
+                    net.output(f);
+                } catch (IllegalStateException e) {
+                    assertReleasedMessage(e);
+                }
+
+                try {
+                    net.fit(f, l);
+                } catch (IllegalStateException e) {
+                    assertReleasedMessage(e);
+                }
+            }
+        }
+    }
+
+    /**
+     * Same scenario as {@link #testCloseMLN()}, but for a ComputationGraph obtained by converting
+     * the test network: close after optional fitting / inference, then check that the parameter
+     * array (and, when fitted, the gradient and updater state arrays) report as closed.
+     */
+    @Test
+    public void testCloseCG() {
+        for (boolean train : new boolean[]{false, true}) {
+            for (boolean test : new boolean[]{false, true}) {
+                ComputationGraph net = getTestNet().toComputationGraph();
+
+                INDArray f = Nd4j.rand(DataType.FLOAT, 16, 1, 28, 28);
+                INDArray l = TestUtils.randomOneHot(16, 10);
+
+                if (train) {
+                    for (int i = 0; i < 3; i++) {
+                        net.fit(new INDArray[]{f}, new INDArray[]{l});
+                    }
+                }
+
+                if (test) {
+                    for (int i = 0; i < 3; i++) {
+                        net.output(f);
+                    }
+                }
+
+                net.close();
+
+                assertTrue(net.params().wasClosed());
+                if (train) {
+                    assertTrue(net.getGradientsViewArray().wasClosed());
+                    Updater u = net.getUpdater(false);
+                    assertTrue(u.getStateViewArray().wasClosed());
+                }
+
+                //Make sure we don't get crashes etc when trying to use after closing: each call may
+                //either complete or throw an IllegalStateException whose message mentions "released"
+                try {
+                    net.output(f);
+                } catch (IllegalStateException e) {
+                    assertReleasedMessage(e);
+                }
+
+                try {
+                    net.fit(new INDArray[]{f}, new INDArray[]{l});
+                } catch (IllegalStateException e) {
+                    assertReleasedMessage(e);
+                }
+            }
+        }
+    }
+}
diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java 
index 30cce7c28..f79763bfe 100644 --- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java +++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/optimize/solver/TestOptimizers.java @@ -1035,5 +1035,9 @@ public class TestOptimizers extends BaseDL4JTest { public boolean updaterDivideByMinibatch(String paramName) { return true; } + + @Override + public void close(){ /* No-op */ + } } } diff --git a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java index a9f92cefc..8cd984044 100644 --- a/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java +++ b/deeplearning4j/deeplearning4j-manifold/deeplearning4j-tsne/src/main/java/org/deeplearning4j/plot/BarnesHutTsne.java @@ -1055,4 +1055,9 @@ public class BarnesHutTsne implements Model { } + + @Override + public void close(){ + //No-op + } } diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/Temp.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/Temp.java new file mode 100644 index 000000000..a93eb558c --- /dev/null +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/Temp.java @@ -0,0 +1,4 @@ +package org.deeplearning4j.nn.modelimport.keras; + +public class Temp { /* NOTE(review): empty placeholder class with no members - confirm it is meant to be part of this change */ +} diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/Model.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/Model.java index 2d5314569..49b32dcc2 100755 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/Model.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/api/Model.java @@ -233,4 
+233,10 @@ public interface Model {
      * Apply any constraints to the model
      */
     void applyConstraints(int iteration, int epoch);
+
+    /**
+     * Close the model and release any resources it holds. Implementations with nothing to
+     * release may treat this as a no-op. Callers should not rely on the model being usable afterwards.
+     */
+    void close();
 }
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
index 571afea7b..2f7bd45ee 100755
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
+++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/graph/ComputationGraph.java
@@ -4824,4 +4824,33 @@ public class ComputationGraph implements Serializable, Model, NeuralNetwork {
         if (cg.getUpdater() != null && cg.getUpdater(false).getStateViewArray() != null)
             this.getUpdater(true).getStateViewArray().assign(cg.getUpdater(false).getStateViewArray());
     }
+
+    /**
+     * Close the network and deallocate all native memory, including: parameters, gradients, updater memory and workspaces
+     * Note that the network should not be used again for any purpose after it has been closed
+     * <p>
+     * Arrays that are absent or not closeable are skipped. Workspaces are destroyed for the
+     * calling thread only.
+     */
+    @Override
+    public void close(){
+        //Close the INDArrays and dealloc. Null guard: presumably params are absent if the network was never initialized - TODO confirm
+        if(flattenedParams != null && flattenedParams.closeable())
+            flattenedParams.close();
+
+        if(flattenedGradients != null && flattenedGradients.closeable())
+            flattenedGradients.close();
+
+        Updater u = getUpdater(false);
+        if(u != null && u.getStateViewArray() != null) {
+            INDArray state = u.getStateViewArray();
+            if(state.closeable())
+                state.close();
+        }
+
+        //Destroys workspaces of the calling thread only
+        Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
+        //Request a garbage collection pass (the JVM treats this as a hint)
+        System.gc();
+    }
 }
diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java
index da03b19c5..f3d009a6c 100644
--- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java
+++ 
b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/nn/layers/AbstractLayer.java @@ -428,4 +428,9 @@ public abstract class AbstractLayer