Use DL4J workspaces for SameDiff layers in MLN/CG (#23)

* #8329 DL4J workspace integration for SameDiff layers

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix bug for Nd4j.createUninitializedDetached for scalars (length 0 shape array)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* SameDiff output layer, graph vertex, various fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Javadoc

Signed-off-by: AlexDBlack <blacka101@gmail.com>
Branch: master
Author: Alex Black, 2019-11-02 17:42:01 +11:00 (committed by GitHub)
Commit: 9efd811508, parent e9a7a13c00
9 changed files with 760 additions and 545 deletions
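
All of the layer and graph-vertex changes below follow one pattern: look up the DL4J workspace names and configurations from the LayerWorkspaceMgr, wrap them in the new DL4JSameDiffMemoryMgr, and attach that manager to the SameDiff InferenceSession for the current thread. A condensed sketch of the forward-pass wiring is shown here; the wrapper class and method are illustrative only, while the individual calls are taken from the diffs that follow.

    import org.deeplearning4j.nn.layers.samediff.DL4JSameDiffMemoryMgr;
    import org.deeplearning4j.nn.workspace.ArrayType;
    import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.autodiff.samediff.internal.InferenceSession;
    import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
    import org.nd4j.base.Preconditions;
    import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;

    class WorkspaceWiringSketch {
        //Illustrative helper: route SameDiff op allocations through DL4J workspaces (forward pass)
        static void useDl4jWorkspaces(SameDiff sameDiff, LayerWorkspaceMgr workspaceMgr) {
            String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.FF_WORKING_MEM);
            String wsNameOutput = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
            WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.FF_WORKING_MEM);
            WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
            boolean actScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATIONS);
            Preconditions.checkState(actScopedOut || wsNameOutput != null,
                    "Activations must have a workspace or must be scoped out");

            //Wrap the DL4J workspaces in a SameDiff memory manager and attach it to this thread's session
            SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameOutput, confWorking, confOutput);
            InferenceSession is = sameDiff.getSessions().get(Thread.currentThread().getId());
            if (is == null) {
                is = new InferenceSession(sameDiff);
                sameDiff.getSessions().put(Thread.currentThread().getId(), is);
            }
            is.setMmgr(mmgr);
        }
    }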


@@ -96,8 +96,8 @@ public class TestBatchNormBp {
        bn.setInput(in, LayerWorkspaceMgr.noWorkspaces());
        Pair<Gradient,INDArray> p = net.backpropGradient(eps, LayerWorkspaceMgr.noWorkspaces());
-       h.preOutput(in, true, new int[]{1,3}, gamma, beta, mean, var, 0.5, e, LayerWorkspaceMgr.noWorkspaces());
-       Pair<Gradient,INDArray> pmkl = h.backpropGradient(in, eps, new int[]{1,3}, gamma, beta, dLdg, dLdb, e, LayerWorkspaceMgr.noWorkspaces());
+       h.preOutput(in, true, new long[]{1,3}, gamma, beta, mean, var, 0.5, e, LayerWorkspaceMgr.noWorkspaces());
+       Pair<Gradient,INDArray> pmkl = h.backpropGradient(in, eps, new long[]{1,3}, gamma, beta, dLdg, dLdb, e, LayerWorkspaceMgr.noWorkspaces());
        INDArray dldin_dl4j = p.getSecond();


@@ -80,154 +80,159 @@ public class TestSameDiffDense extends BaseDL4JTest {
    @Test
    public void testSameDiffDenseForward() {
+       for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
            for (int minibatch : new int[]{5, 1}) {
                int nIn = 3;
                int nOut = 4;

                Activation[] afns = new Activation[]{
                        Activation.TANH,
                        Activation.SIGMOID,
                        Activation.ELU,
                        Activation.IDENTITY,
                        Activation.SOFTPLUS,
                        Activation.SOFTSIGN,
                        Activation.CUBE,
                        Activation.HARDTANH,
                        Activation.RELU
                };

                for (Activation a : afns) {
-                   log.info("Starting test - " + a);
+                   log.info("Starting test - " + a + ", workspace = " + wsm);
                    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
+                           .inferenceWorkspaceMode(wsm)
+                           .trainingWorkspaceMode(wsm)
                            .list()
                            .layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut)
                                    .activation(a)
                                    .build())
                            .build();

                    MultiLayerNetwork net = new MultiLayerNetwork(conf);
                    net.init();

                    assertNotNull(net.paramTable());

                    MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
                            .list()
                            .layer(new DenseLayer.Builder().activation(a).nIn(nIn).nOut(nOut).build())
                            .build();

                    MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
                    net2.init();

                    net.params().assign(net2.params());

                    //Check params:
                    assertEquals(net2.params(), net.params());
                    Map<String, INDArray> params1 = net.paramTable();
                    Map<String, INDArray> params2 = net2.paramTable();
                    assertEquals(params2, params1);

                    INDArray in = Nd4j.rand(minibatch, nIn);
                    INDArray out = net.output(in);
                    INDArray outExp = net2.output(in);
                    assertEquals(outExp, out);

                    //Also check serialization:
                    MultiLayerNetwork netLoaded = TestUtils.testModelSerialization(net);
                    INDArray outLoaded = netLoaded.output(in);
                    assertEquals(outExp, outLoaded);

                    //Sanity check on different minibatch sizes:
                    INDArray newIn = Nd4j.vstack(in, in);
                    INDArray outMbsd = net.output(newIn);
                    INDArray outMb = net2.output(newIn);
                    assertEquals(outMb, outMbsd);
                }
            }
+       }
    }

    @Test
    public void testSameDiffDenseForwardMultiLayer() {
+       for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
            for (int minibatch : new int[]{5, 1}) {
                int nIn = 3;
                int nOut = 4;

                Activation[] afns = new Activation[]{
                        Activation.TANH,
                        Activation.SIGMOID,
                        Activation.ELU,
                        Activation.IDENTITY,
                        Activation.SOFTPLUS,
                        Activation.SOFTSIGN,
                        Activation.CUBE,        //https://github.com/deeplearning4j/nd4j/issues/2426
                        Activation.HARDTANH,
                        Activation.RELU         //JVM crash
                };

                for (Activation a : afns) {
-                   log.info("Starting test - " + a);
+                   log.info("Starting test - " + a + " - workspace=" + wsm);
                    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                            .seed(12345)
                            .list()
                            .layer(new SameDiffDense.Builder().nIn(nIn).nOut(nOut)
                                    .weightInit(WeightInit.XAVIER)
                                    .activation(a).build())
                            .layer(new SameDiffDense.Builder().nIn(nOut).nOut(nOut)
                                    .weightInit(WeightInit.XAVIER)
                                    .activation(a).build())
                            .layer(new OutputLayer.Builder().nIn(nOut).nOut(nOut)
                                    .weightInit(WeightInit.XAVIER)
                                    .activation(a).build())
                            .validateOutputLayerConfig(false)
                            .build();

                    MultiLayerNetwork net = new MultiLayerNetwork(conf);
                    net.init();

                    assertNotNull(net.paramTable());

                    MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
                            .seed(12345)
                            .weightInit(WeightInit.XAVIER)
                            .list()
                            .layer(new DenseLayer.Builder().activation(a).nIn(nIn).nOut(nOut).build())
                            .layer(new DenseLayer.Builder().activation(a).nIn(nOut).nOut(nOut).build())
                            .layer(new OutputLayer.Builder().nIn(nOut).nOut(nOut)
                                    .activation(a).build())
                            .validateOutputLayerConfig(false)
                            .build();

                    MultiLayerNetwork net2 = new MultiLayerNetwork(conf2);
                    net2.init();

-                   // net.params().assign(net2.params());
                    assertEquals(net2.params(), net.params());

                    //Check params:
                    assertEquals(net2.params(), net.params());
                    Map<String, INDArray> params1 = net.paramTable();
                    Map<String, INDArray> params2 = net2.paramTable();
                    assertEquals(params2, params1);

                    INDArray in = Nd4j.rand(minibatch, nIn);
                    INDArray out = net.output(in);
                    INDArray outExp = net2.output(in);
                    assertEquals(outExp, out);

                    //Also check serialization:
                    MultiLayerNetwork netLoaded = TestUtils.testModelSerialization(net);
                    INDArray outLoaded = netLoaded.output(in);
                    assertEquals(outExp, outLoaded);

                    //Sanity check different minibatch sizes
                    in = Nd4j.rand(2 * minibatch, nIn);
                    out = net.output(in);
                    outExp = net2.output(in);
                    assertEquals(outExp, out);
                }
            }
+       }
    }
@@ -244,10 +249,13 @@ public class TestSameDiffDense extends BaseDL4JTest {
        Activation[] afns = new Activation[]{
                Activation.TANH,
                Activation.SIGMOID,
-               Activation.ELU, Activation.IDENTITY, Activation.SOFTPLUS, Activation.SOFTSIGN,
+               Activation.ELU,
+               Activation.IDENTITY,
+               Activation.SOFTPLUS,
+               Activation.SOFTSIGN,
                Activation.HARDTANH,
-               Activation.CUBE,    //https://github.com/deeplearning4j/nd4j/issues/2426
-               Activation.RELU     //JVM crash
+               Activation.CUBE,
+               Activation.RELU
        };

        for (Activation a : afns) {
@@ -337,64 +345,66 @@ public class TestSameDiffDense extends BaseDL4JTest {
        int nIn = 4;
        int nOut = 3;

-       boolean workspaces = true;
+       for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {

            MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .seed(12345)
-                   .trainingWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
-                   .inferenceWorkspaceMode(workspaces ? WorkspaceMode.ENABLED : WorkspaceMode.NONE)
+                   .trainingWorkspaceMode(wsm)
+                   .inferenceWorkspaceMode(wsm)
                    .updater(new Adam(0.1))
                    .list()
                    .layer(new SameDiffDense.Builder().nIn(nIn).nOut(5).activation(Activation.TANH).build())
                    .layer(new SameDiffDense.Builder().nIn(5).nOut(5).activation(Activation.TANH).build())
                    .layer(new OutputLayer.Builder().nIn(5).nOut(nOut).activation(Activation.SOFTMAX)
                            .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();

            MultiLayerNetwork netSD = new MultiLayerNetwork(conf);
            netSD.init();

            MultiLayerConfiguration conf2 = new NeuralNetConfiguration.Builder()
                    .seed(12345)
                    .updater(new Adam(0.1))
                    .list()
                    .layer(new DenseLayer.Builder().activation(Activation.TANH).nIn(nIn).nOut(5).build())
                    .layer(new DenseLayer.Builder().activation(Activation.TANH).nIn(5).nOut(5).build())
                    .layer(new OutputLayer.Builder().nIn(5).nOut(nOut).activation(Activation.SOFTMAX)
                            .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();

            MultiLayerNetwork netStandard = new MultiLayerNetwork(conf2);
            netStandard.init();

            netSD.params().assign(netStandard.params());

            //Check params:
            assertEquals(netStandard.params(), netSD.params());
            assertEquals(netStandard.paramTable(), netSD.paramTable());

            DataSetIterator iter = new IrisDataSetIterator(150, 150);
            DataSet ds = iter.next();

            INDArray outSD = netSD.output(ds.getFeatures());
            INDArray outStd = netStandard.output(ds.getFeatures());
            assertEquals(outStd, outSD);

            for (int i = 0; i < 3; i++) {
                netSD.fit(ds);
                netStandard.fit(ds);
                String s = String.valueOf(i);
                assertEquals(s, netStandard.getFlattenedGradients(), netSD.getFlattenedGradients());
                assertEquals(s, netStandard.params(), netSD.params());
                assertEquals(s, netStandard.getUpdater().getStateViewArray(), netSD.getUpdater().getStateViewArray());
            }

            //Sanity check on different minibatch sizes:
            INDArray newIn = Nd4j.vstack(ds.getFeatures(), ds.getFeatures());
            INDArray outMbsd = netSD.output(newIn);
            INDArray outMb = netStandard.output(newIn);
            assertEquals(outMb, outMbsd);
+       }
    }

    @Test
@@ -402,7 +412,7 @@ public class TestSameDiffDense extends BaseDL4JTest {
        int nIn = 4;
        int nOut = 4;

-       for (boolean workspaces : new boolean[]{false, true}) {
+       for (boolean workspaces : new boolean[]{true, false}) {
            for (Activation a : new Activation[]{Activation.TANH, Activation.IDENTITY}) {
                String msg = "workspaces: " + workspaces + ", " + a;


@@ -21,6 +21,7 @@ import org.deeplearning4j.BaseDL4JTest;
import org.deeplearning4j.TestUtils;
import org.deeplearning4j.nn.conf.ComputationGraphConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
+import org.deeplearning4j.nn.conf.WorkspaceMode;
import org.deeplearning4j.nn.conf.graph.ElementWiseVertex;
import org.deeplearning4j.nn.conf.graph.ScaleVertex;
import org.deeplearning4j.nn.conf.graph.ShiftVertex;
@@ -52,152 +53,169 @@ public class TestSameDiffLambda extends BaseDL4JTest {
    @Test
    public void testSameDiffLamdaLayerBasic(){
+       for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
+           log.info("--- Workspace Mode: {} ---", wsm);

            Nd4j.getRandom().setSeed(12345);
            ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
+                   .trainingWorkspaceMode(wsm)
+                   .inferenceWorkspaceMode(wsm)
                    .seed(12345)
                    .updater(new Adam(0.01))
                    .graphBuilder()
                    .addInputs("in")
                    .addLayer("0", new DenseLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).build(), "in")
                    .addLayer("1", new SameDiffSimpleLambdaLayer(), "0")
                    .addLayer("2", new OutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX)
                            .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "1")
                    .setOutputs("2")
                    .build();

            //Equavalent, not using SameDiff Lambda:
            ComputationGraphConfiguration confStd = new NeuralNetConfiguration.Builder()
+                   .trainingWorkspaceMode(wsm)
+                   .inferenceWorkspaceMode(wsm)
                    .seed(12345)
                    .updater(new Adam(0.01))
                    .graphBuilder()
                    .addInputs("in")
                    .addLayer("0", new DenseLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).build(), "in")
                    .addVertex("1", new ShiftVertex(1.0), "0")
                    .addVertex("2", new ScaleVertex(2.0), "1")
                    .addLayer("3", new OutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX)
                            .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "2")
                    .setOutputs("3")
                    .build();

            ComputationGraph lambda = new ComputationGraph(conf);
            lambda.init();

            ComputationGraph std = new ComputationGraph(confStd);
            std.init();

            lambda.setParams(std.params());

            INDArray in = Nd4j.rand(3, 5);
            INDArray labels = TestUtils.randomOneHot(3, 5);
            DataSet ds = new DataSet(in, labels);

            INDArray outLambda = lambda.outputSingle(in);
            INDArray outStd = std.outputSingle(in);
            assertEquals(outLambda, outStd);

            double scoreLambda = lambda.score(ds);
            double scoreStd = std.score(ds);
            assertEquals(scoreStd, scoreLambda, 1e-6);

            for (int i = 0; i < 3; i++) {
                lambda.fit(ds);
                std.fit(ds);

                String s = String.valueOf(i);
                assertEquals(s, std.params(), lambda.params());
                assertEquals(s, std.getFlattenedGradients(), lambda.getFlattenedGradients());
            }

            ComputationGraph loaded = TestUtils.testModelSerialization(lambda);
            outLambda = loaded.outputSingle(in);
            outStd = std.outputSingle(in);
            assertEquals(outStd, outLambda);

            //Sanity check on different minibatch sizes:
            INDArray newIn = Nd4j.vstack(in, in);
            INDArray outMbsd = lambda.output(newIn)[0];
            INDArray outMb = std.output(newIn)[0];
            assertEquals(outMb, outMbsd);
+       }
    }
    @Test
    public void testSameDiffLamdaVertexBasic(){
+       for(WorkspaceMode wsm : new WorkspaceMode[]{WorkspaceMode.ENABLED, WorkspaceMode.NONE}) {
+           log.info("--- Workspace Mode: {} ---", wsm);

            Nd4j.getRandom().setSeed(12345);
            ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
+                   .trainingWorkspaceMode(wsm)
+                   .inferenceWorkspaceMode(wsm)
                    .dataType(DataType.DOUBLE)
                    .seed(12345)
                    .updater(new Adam(0.01))
                    .graphBuilder()
                    .addInputs("in1", "in2")
                    .addLayer("0", new DenseLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).build(), "in1")
                    .addLayer("1", new DenseLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).build(), "in2")
                    .addVertex("lambda", new SameDiffSimpleLambdaVertex(), "0", "1")
                    .addLayer("2", new OutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX)
                            .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "lambda")
                    .setOutputs("2")
                    .build();

            //Equavalent, not using SameDiff Lambda:
            ComputationGraphConfiguration confStd = new NeuralNetConfiguration.Builder()
+                   .trainingWorkspaceMode(wsm)
+                   .inferenceWorkspaceMode(wsm)
                    .dataType(DataType.DOUBLE)
                    .seed(12345)
                    .updater(new Adam(0.01))
                    .graphBuilder()
                    .addInputs("in1", "in2")
                    .addLayer("0", new DenseLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).build(), "in1")
                    .addLayer("1", new DenseLayer.Builder().nIn(5).nOut(5).activation(Activation.TANH).build(), "in2")
                    .addVertex("elementwise", new ElementWiseVertex(ElementWiseVertex.Op.Product), "0", "1")
                    .addLayer("3", new OutputLayer.Builder().nIn(5).nOut(5).activation(Activation.SOFTMAX)
                            .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "elementwise")
                    .setOutputs("3")
                    .build();

            ComputationGraph lambda = new ComputationGraph(conf);
            lambda.init();

            ComputationGraph std = new ComputationGraph(confStd);
            std.init();

            lambda.setParams(std.params());

            INDArray in1 = Nd4j.rand(3, 5);
            INDArray in2 = Nd4j.rand(3, 5);
            INDArray labels = TestUtils.randomOneHot(3, 5);
            MultiDataSet mds = new org.nd4j.linalg.dataset.MultiDataSet(new INDArray[]{in1, in2}, new INDArray[]{labels});

            INDArray outLambda = lambda.output(in1, in2)[0];
            INDArray outStd = std.output(in1, in2)[0];
            assertEquals(outLambda, outStd);

            double scoreLambda = lambda.score(mds);
            double scoreStd = std.score(mds);
            assertEquals(scoreStd, scoreLambda, 1e-6);

            for (int i = 0; i < 3; i++) {
                lambda.fit(mds);
                std.fit(mds);

                String s = String.valueOf(i);
                assertEquals(s, std.params(), lambda.params());
                assertEquals(s, std.getFlattenedGradients(), lambda.getFlattenedGradients());
            }

            ComputationGraph loaded = TestUtils.testModelSerialization(lambda);
            outLambda = loaded.output(in1, in2)[0];
            outStd = std.output(in1, in2)[0];
            assertEquals(outStd, outLambda);

            //Sanity check on different minibatch sizes:
            INDArray newIn1 = Nd4j.vstack(in1, in1);
            INDArray newIn2 = Nd4j.vstack(in2, in2);
            INDArray outMbsd = lambda.output(newIn1, newIn2)[0];
            INDArray outMb = std.output(newIn1, newIn2)[0];
            assertEquals(outMb, outMbsd);
+       }
    }
}


@@ -0,0 +1,68 @@
package org.deeplearning4j.nn.layers.samediff;

import org.nd4j.autodiff.samediff.internal.memory.AbstractMemoryMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;

/**
 * A SameDiff {@link org.nd4j.autodiff.samediff.internal.SessionMemMgr} that uses DL4J workspaces for memory management.
 * Any op outputs are allocated in the output workspace if they are returned to the layer; otherwise they are placed in
 * the DL4J working memory workspace
 *
 * @author Alex Black
 */
public class DL4JSameDiffMemoryMgr extends AbstractMemoryMgr {

    private final String workingMemoryWs;
    private final String outputWs;
    private final WorkspaceConfiguration confWorking;
    private final WorkspaceConfiguration confOutput;

    //Note: if the working memory or output workspace names are null -> detached memory
    public DL4JSameDiffMemoryMgr(String workingMemoryWs, String outputWs, WorkspaceConfiguration confWorking,
                                 WorkspaceConfiguration confOutput){
        this.workingMemoryWs = workingMemoryWs;
        this.outputWs = outputWs;
        this.confWorking = confWorking;
        this.confOutput = confOutput;
    }

    @Override
    public INDArray allocate(boolean detached, DataType dataType, long... shape) {
        String wsName = detached ? outputWs : workingMemoryWs;
        WorkspaceConfiguration wsConf = detached ? confOutput : confWorking;

        if(wsName == null){
            //Scoped out
            INDArray ret = Nd4j.createUninitializedDetached(dataType, shape);
            Preconditions.checkState(!ret.isAttached(), "Returned array should be detached");
            return ret;
        } else {
            MemoryWorkspace ws = Nd4j.getWorkspaceManager().getWorkspaceForCurrentThread(wsConf, wsName);
            try (MemoryWorkspace mw = ws.notifyScopeBorrowed()) {
                return Nd4j.createUninitialized(dataType, shape);
            }
        }
    }

    @Override
    public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
        return allocate(detached, descriptor.dataType(), descriptor.getShape());
    }

    @Override
    public void release(INDArray array) {
        //No-op - DL4J workspaces handles this
    }

    @Override
    public void close() {
        //No-op - DL4J workspaces handles this
    }
}
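
In effect, each allocation temporarily borrows the requested DL4J workspace, so the returned array is attached to that workspace (or detached, if no workspace name was supplied) rather than to whatever workspace happens to be open. A rough usage sketch, with made-up workspace names and a default WorkspaceConfiguration (illustrative only, not part of the commit):

    import org.deeplearning4j.nn.layers.samediff.DL4JSameDiffMemoryMgr;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
    import org.nd4j.linalg.api.ndarray.INDArray;

    public class MemoryMgrUsageSketch {
        public static void main(String[] args) {
            //Hypothetical workspace names and a default configuration - illustrative only
            WorkspaceConfiguration conf = WorkspaceConfiguration.builder().build();
            DL4JSameDiffMemoryMgr mmgr =
                    new DL4JSameDiffMemoryMgr("WS_WORKING_MEM", "WS_LAYER_ACT", conf, conf);

            //detached == true -> allocated (borrowed) in the output workspace "WS_LAYER_ACT"
            INDArray activations = mmgr.allocate(true, DataType.FLOAT, 32, 128);
            //detached == false -> allocated in the working-memory workspace "WS_WORKING_MEM"
            INDArray scratch = mmgr.allocate(false, DataType.FLOAT, 32, 128);

            //release() and close() are no-ops: the DL4J workspaces own the memory lifecycle
            mmgr.release(scratch);
            mmgr.close();
        }
    }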


@@ -31,9 +31,12 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.autodiff.samediff.internal.InferenceSession;
+import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
+import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.factory.Nd4j;
@@ -95,119 +98,159 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
    @Override
    public INDArray doForward(boolean training, LayerWorkspaceMgr workspaceMgr) {
        try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            if (sameDiff == null) {
                doInit();
            }
        }

        Map<String,INDArray> phMap = new HashMap<>();
        config.validateInput(inputs);
        for(int i=0; i<inputs.length; i++ ){
            String name = config.getVertexParams().getInputs().get(i);
            final String maskName = name + "_mask";
            phMap.put(name, inputs[i]);
            if(maskArrays != null && maskArrays[i] != null) {
                phMap.put(maskName, maskArrays[i]);
            }else{
                phMap.put(maskName, createMask(dataType, inputs[i].shape()));
            }
        }

+       //Configure memory management for SameDiff instance - use DL4J workspaces
+       String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.FF_WORKING_MEM);
+       String wsNameOutput = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
+       WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.FF_WORKING_MEM);
+       WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
+       boolean actScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATIONS);
+       Preconditions.checkState(actScopedOut || wsNameOutput != null, "Activations must have a workspace or must be scoped out");
+       SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameOutput, confWorking, confOutput);
+
+       InferenceSession is = sameDiff.getSessions().get(Thread.currentThread().getId());
+       if(is == null){
+           is = new InferenceSession(sameDiff);
+           sameDiff.getSessions().put(Thread.currentThread().getId(), is);
+       }
+       is.setMmgr(mmgr);

        if(paramTable != null && paramTable.size() > 0) {
            //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
            //TODO Find a more efficient solution for this
            for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
                INDArray arr = e.getValue();
                sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
            }
        }
        INDArray result = sameDiff.outputSingle(phMap, outputKey);

+       //Edge case: "vertex" is just an identity activation, for example
+       //TODO there may be a cleaner way to do this...
+       if(!actScopedOut && !result.data().getParentWorkspace().getId().equals(wsNameOutput)){
+           result = workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
+       } else if(actScopedOut && result.isAttached()){
+           result = result.detach();
+       }

        //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
        sameDiff.clearPlaceholders(true);
        sameDiff.clearOpInputs();

        return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
    }
    @Override
    public Pair<Gradient, INDArray[]> doBackward(boolean tbptt, LayerWorkspaceMgr workspaceMgr) {
        Gradient g = new DefaultGradient();

-       INDArray[] dLdIns;
-       boolean[] noClose = new boolean[getNumInputArrays()];
        try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            if (sameDiff == null) {
                doInit();
            }
        }

        List<String> inputNames = config.getVertexParams().getInputs();
        if(!sameDiff.hasGradientFunction()) {
            //Create when scoped out, to ensure any arrays are not in WS
            String[] inArr = inputNames.toArray(new String[inputNames.size()]);
            sameDiff.createGradFunction(inArr);
        }
        config.validateInput(inputs);

+       //Configure memory management for SameDiff instance - use DL4J workspaces
+       Map<Long,InferenceSession> sessionMap = sameDiff.getFunction("grad").getSessions();
+       if(!sessionMap.containsKey(Thread.currentThread().getId())){
+           sessionMap.put(Thread.currentThread().getId(), new InferenceSession(sameDiff.getFunction("grad")));
+       }
+       String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.BP_WORKING_MEM);
+       String wsNameActGrad = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATION_GRAD);
+       WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.BP_WORKING_MEM);
+       WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATION_GRAD);
+       boolean actGradScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATION_GRAD);
+       Preconditions.checkState(actGradScopedOut || wsNameActGrad != null, "Activation gradients must have a workspace or be scoped out");
+       SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameActGrad, confWorking, confOutput);
+       sessionMap.get(Thread.currentThread().getId()).setMmgr(mmgr);

        Map<String,INDArray> phMap = new HashMap<>();
        List<String> inputs = config.getVertexParams().getInputs();
        int i=0;
        for(String s : inputs){
            phMap.put(s, this.inputs[i++]);
        }
        for( int j=0; j<this.inputs.length; j++ ){
            String name = inputs.get(j);
            final String maskName = name + "_mask";
            if(maskArrays != null && maskArrays[j] != null) {
                phMap.put(maskName, maskArrays[j]);
            }else{
                phMap.put(maskName, createMask(dataType, this.inputs[j].shape()));
            }
        }
        String epsName = fn.getGradPlaceholderName();
        phMap.put(epsName, epsilon);

        //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
        //TODO Find a more efficient solution for this
        List<String> required = new ArrayList<>(inputNames.size());     //Ensure that the input placeholder gradients are calculated
        for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
            INDArray arr = e.getValue();
            sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
        }
        required.addAll(paramTable.keySet());
        required.addAll(inputNames);

        Map<String,INDArray> gradsMap = sameDiff.calculateGradients(phMap, required);
        for(String s : paramTable.keySet() ){
            INDArray sdGrad = gradsMap.get(s);
            INDArray dl4jGrad = gradTable.get(s);
            dl4jGrad.assign(sdGrad);                                    //TODO OPTIMIZE THIS
-           sdGrad.close();                                             //TODO optimize this
            g.gradientForVariable().put(s, dl4jGrad);
        }

-       dLdIns = new INDArray[inputs.size()];
+       INDArray[] dLdIns = new INDArray[inputs.size()];
        String fnName = fn.getGradPlaceholderName();
        for(int j=0; j<inputs.size(); j++ ){
            String name = inputs.get(j);
            dLdIns[j] = sameDiff.grad(name).getArr();

            String gradName = sameDiff.grad(inputNames.get(j)).name();
            if(dLdIns[j] == null && fnName.equals(gradName)){
                //Edge case with lambda vertices like identity: SameDiff doesn't store the placeholders
                // So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here
                dLdIns[j] = epsilon;
-               noClose[j] = true;
            }

+           //Edge case: "vertex" is just an identity activation, for example
+           //TODO there may be a cleaner way to do this...
+           if(!actGradScopedOut && !dLdIns[j].data().getParentWorkspace().getId().equals(wsNameActGrad)){
+               dLdIns[j] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIns[j]);
+           } else if(actGradScopedOut && dLdIns[j].isAttached()){
+               dLdIns[j] = dLdIns[j].detach();
+           }
        }
-
-       //TODO optimize
-       for( int i=0; i<dLdIns.length; i++ ){
-           INDArray before = dLdIns[i];
-           dLdIns[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIns[i]);
-           if(!noClose[i]){
-               before.close();
-           }
-       }


@@ -26,9 +26,12 @@ import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.autodiff.samediff.internal.InferenceSession;
+import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
+import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.factory.Nd4j;
@@ -81,43 +84,62 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
        assertInputSet(false);
        try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            if (sameDiff == null) {
                doInit();
            }
        }

        org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
        bl.validateInput(input);

        Map<String,INDArray> phMap = new HashMap<>();
        phMap.put(INPUT_KEY, input);
        if(maskArray != null){
            phMap.put(MASK_KEY, maskArray);
        } else {
            phMap.put(MASK_KEY, layerConf().onesMaskForInput(input));
        }

        //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
        //TODO Find a more efficient solution for this
        for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
            INDArray arr = e.getValue();
            sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
        }

+       //Configure memory management for SameDiff instance - use DL4J workspaces
+       String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.FF_WORKING_MEM);
+       String wsNameOutput = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
+       WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.FF_WORKING_MEM);
+       WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
+       boolean actScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATIONS);
+       Preconditions.checkState(actScopedOut || wsNameOutput != null, "Activations must have a workspace or must be scoped out");
+       SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameOutput, confWorking, confOutput);
+
+       InferenceSession is = sameDiff.getSessions().get(Thread.currentThread().getId());
+       if(is == null){
+           is = new InferenceSession(sameDiff);
+           sameDiff.getSessions().put(Thread.currentThread().getId(), is);
+       }
+       is.setMmgr(mmgr);

        Map<String,INDArray> out = sameDiff.output(phMap, outputKey);
        INDArray result = out.get(outputKey);

+       //Edge case - identity activation
+       //TODO there may be a cleaner way to do this...
+       if(!actScopedOut && !result.data().getParentWorkspace().getId().equals(wsNameOutput)){
+           result = workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
+       } else if(actScopedOut && result.isAttached()){
+           result = result.detach();
+       }

        //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
        sameDiff.clearPlaceholders(true);
        sameDiff.clearOpInputs();

-       INDArray ret = workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
-       if(!result.isAttached() && result.closeable()) {
-           //May be attached in rare edge case - for identity, or if gradients are passed through from output to input
-           // unchaned, as in identity, add scalar, etc
-           result.close();
-       }
-       return ret;
+       return result;
    }
@@ -128,67 +150,71 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
        Gradient g = new DefaultGradient();

        INDArray dLdIn;
-       boolean noCloseEps = false;
        try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            if (sameDiff == null) {
                doInit();
            }
            if (!sameDiff.hasGradientFunction()) {
                //Create when scoped out, to ensure any arrays are not in WS
                sameDiff.createGradFunction(INPUT_KEY);
            }
        }

+       //Configure memory management for SameDiff instance - use DL4J workspaces
+       Map<Long,InferenceSession> sessionMap = sameDiff.getFunction("grad").getSessions();
+       if(!sessionMap.containsKey(Thread.currentThread().getId())){
+           sessionMap.put(Thread.currentThread().getId(), new InferenceSession(sameDiff.getFunction("grad")));
+       }
+       String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.BP_WORKING_MEM);
+       String wsNameActGrad = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATION_GRAD);
+       WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.BP_WORKING_MEM);
+       WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATION_GRAD);
+       boolean actGradScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATION_GRAD);
+       Preconditions.checkState(actGradScopedOut || wsNameActGrad != null, "Activation gradients must have a workspace or be scoped out");
+       SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameActGrad, confWorking, confOutput);
+       sessionMap.get(Thread.currentThread().getId()).setMmgr(mmgr);

        org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer bl = (org.deeplearning4j.nn.conf.layers.samediff.SameDiffLayer) layerConf();
        bl.validateInput(input);

        //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
        //TODO Find a more efficient solution for this
        for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
            INDArray arr = e.getValue();
            sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
        }

        Map<String,INDArray> phMap = new HashMap<>();
        phMap.put(INPUT_KEY, input);
        phMap.put(fn.getGradPlaceholderName(), epsilon);
        if(maskArray != null){
            phMap.put(MASK_KEY, maskArray);
        } else {
            phMap.put(MASK_KEY, layerConf().onesMaskForInput(input));
        }

        List<String> requiredGrads = new ArrayList<>(paramTable.size() + 1);
        requiredGrads.add(INPUT_KEY);
        requiredGrads.addAll(paramTable.keySet());

        Map<String,INDArray> m = sameDiff.calculateGradients(phMap, requiredGrads);
        for(String s : paramTable.keySet() ){
            INDArray sdGrad = m.get(s);
            INDArray dl4jGrad = gradTable.get(s);
            dl4jGrad.assign(sdGrad);                                    //TODO OPTIMIZE THIS
            g.gradientForVariable().put(s, dl4jGrad);
-           sdGrad.close();
        }

        dLdIn = m.get(INPUT_KEY);
-       if(dLdIn == null && fn.getGradPlaceholderName().equals(INPUT_KEY)){
-           //Edge case with lambda layers like identity: SameDiff doesn't store the placeholders
-           // So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here
-           dLdIn = epsilon;
-           noCloseEps = true;
-       }

        //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
        sameDiff.clearPlaceholders(true);
        sameDiff.clearOpInputs();

        Pair<Gradient, INDArray> ret = new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn));      //TODO OPTIMIZE THIS
-       if(!noCloseEps && !dLdIn.isAttached() && dLdIn.closeable()) {
-           //Edge case: identity etc - might just pass gradient array through unchanged
-           dLdIn.close();
-       }
        return ret;
    }


@@ -29,9 +29,12 @@ import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
+import org.nd4j.autodiff.samediff.internal.InferenceSession;
+import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
+import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.dataset.api.DataSet;
@@ -95,40 +98,59 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
        //TODO optimize
        try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
            if (sameDiff == null) {
                doInit();
            }
        }

+       //Configure memory management for SameDiff instance - use DL4J workspaces
+       String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.FF_WORKING_MEM);
+       String wsNameOutput = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATIONS);
+       WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.FF_WORKING_MEM);
+       WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATIONS);
+       boolean actScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATIONS);
+       Preconditions.checkState(actScopedOut || wsNameOutput != null, "Activations must have a workspace or must be scoped out");
+       SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameOutput, confWorking, confOutput);
+
+       InferenceSession is = sameDiff.getSessions().get(Thread.currentThread().getId());
+       if(is == null){
+           is = new InferenceSession(sameDiff);
+           sameDiff.getSessions().put(Thread.currentThread().getId(), is);
+       }
+       is.setMmgr(mmgr);

        //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
        //TODO Find a more efficient solution for this
        for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
            INDArray arr = e.getValue();
            sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
        }

        Map<String,INDArray> phMap = new HashMap<>();
        phMap.put(INPUT_KEY, input);
        if(!activations && layerConf().labelsRequired() && labels != null) {
            phMap.put(LABELS_KEY, labels);
        }

        String s = activations ? layerConf().activationsVertexName() : outputVar.name();

        INDArray out = sameDiff.outputSingle(phMap, s);

        //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
        sameDiff.clearPlaceholders(true);
        sameDiff.clearOpInputs();

-       if(activations) {
-           Preconditions.checkNotNull(out, "Activations (result) array for variable \"%s\" was " +
-                   "null - error during execution or this variable (as defined by method activationsVertexName()) " +
-                   "does not exist", layerConf().activationsVertexName());
-           return workspaceMgr.dup(ArrayType.ACTIVATIONS, out);
-       } else {
-           return out;
-       }
+       //Edge case: vertex is just an Identity function, for example
+       //TODO there may be a cleaner way to do this...
+       if(!actScopedOut && !out.data().getParentWorkspace().getId().equals(wsNameOutput)){
+           out = workspaceMgr.dup(ArrayType.ACTIVATIONS, out);
+       } else if(actScopedOut && out.isAttached()){
+           out = out.detach();
+       }
+
+       return out;
    }
@ -141,54 +163,76 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
        Gradient g = new DefaultGradient();

        INDArray dLdIn;

-       try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
-           if(sameDiff == null){
+       try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()) {
+           if (sameDiff == null) {
                //Usually doInit will be called in forward pass; not necessarily the case in output layers
                // (for efficiency, we skip output layer forward pass in MultiLayerNetwork/ComputationGraph)
                doInit();
            }
-           if(!sameDiff.hasGradientFunction()) {
-               //Create when scoped out, to ensure any arrays are not in WS
+           if(sameDiff.getFunction("grad") == null)
                sameDiff.createGradFunction(INPUT_KEY);
-           }
-
-           //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
-           //TODO Find a more efficient solution for this
-           for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
-               INDArray arr = e.getValue();
-               sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
-           }
-
-           List<String> gradVarNames = new ArrayList<>();
-           gradVarNames.addAll(paramTable.keySet());
-           gradVarNames.add(INPUT_KEY);
-
-           Map<String,INDArray> phMap = new HashMap<>();
-           phMap.put(INPUT_KEY, input);
-           phMap.put(LABELS_KEY, labels);
-
-           Map<String,INDArray> grads = sameDiff.calculateGradients(phMap, gradVarNames);
-           for(String s : paramTable.keySet() ){
-               INDArray sdGrad = grads.get(s);
-               INDArray dl4jGrad = gradTable.get(s);
-               dl4jGrad.assign(sdGrad);   //TODO OPTIMIZE THIS
-               g.gradientForVariable().put(s, dl4jGrad);
-               if(sdGrad.closeable()){
-                   sdGrad.close();
-               }
-           }
-
-           dLdIn = grads.get(INPUT_KEY);
        }

+       //Configure memory management for SameDiff instance - use DL4J workspaces
+       Map<Long,InferenceSession> sessionMap = sameDiff.getFunction("grad").getSessions();
+       if(!sessionMap.containsKey(Thread.currentThread().getId())){
+           sessionMap.put(Thread.currentThread().getId(), new InferenceSession(sameDiff.getFunction("grad")));
+       }
+       String wsNameWorking = workspaceMgr.getWorkspaceName(ArrayType.BP_WORKING_MEM);
+       String wsNameActGrad = workspaceMgr.getWorkspaceName(ArrayType.ACTIVATION_GRAD);
+       WorkspaceConfiguration confWorking = workspaceMgr.getConfiguration(ArrayType.BP_WORKING_MEM);
+       WorkspaceConfiguration confOutput = workspaceMgr.getConfiguration(ArrayType.ACTIVATION_GRAD);
+
+       boolean actGradScopedOut = workspaceMgr.isScopedOut(ArrayType.ACTIVATION_GRAD);
+       Preconditions.checkState(actGradScopedOut || wsNameActGrad != null, "Activation gradients must have a workspace or be scoped out");
+       SessionMemMgr mmgr = new DL4JSameDiffMemoryMgr(wsNameWorking, wsNameActGrad, confWorking, confOutput);
+       sessionMap.get(Thread.currentThread().getId()).setMmgr(mmgr);
+
+       if(!sameDiff.hasGradientFunction()) {
+           //Create when scoped out, to ensure any arrays are not in WS
+           sameDiff.createGradFunction(INPUT_KEY);
+       }
+
+       //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
+       //TODO Find a more efficient solution for this
+       for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
+           INDArray arr = e.getValue();
+           sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
+       }
+
+       List<String> gradVarNames = new ArrayList<>();
+       gradVarNames.addAll(paramTable.keySet());
+       gradVarNames.add(INPUT_KEY);
+
+       Map<String,INDArray> phMap = new HashMap<>();
+       phMap.put(INPUT_KEY, input);
+       phMap.put(LABELS_KEY, labels);
+
+       Map<String,INDArray> grads = sameDiff.calculateGradients(phMap, gradVarNames);
+       for(String s : paramTable.keySet() ){
+           INDArray sdGrad = grads.get(s);
+           INDArray dl4jGrad = gradTable.get(s);
+           dl4jGrad.assign(sdGrad);   //TODO OPTIMIZE THIS
+           g.gradientForVariable().put(s, dl4jGrad);
+           if(sdGrad.closeable()){
+               sdGrad.close();
+           }
+       }
+
+       dLdIn = grads.get(INPUT_KEY);

        //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere
        sameDiff.clearPlaceholders(true);
        sameDiff.clearOpInputs();

-       Pair<Gradient,INDArray> p = new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn));   //TODO OPTIMIZE THIS
-       if(dLdIn.closeable())
-           dLdIn.close();
-       return p;
+       //TODO there may be a cleaner way to do this...
+       if(!actGradScopedOut && !dLdIn.data().getParentWorkspace().getId().equals(wsNameActGrad)){
+           dLdIn = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn);
+       } else if(actGradScopedOut && dLdIn.isAttached()){
+           dLdIn = dLdIn.detach();
+       }
+       return new Pair<>(g, dLdIn);
    }
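The "parameters are views" comment in the hunk above is why assignArray is called on every iteration rather than once. The snippet below is a minimal standalone illustration of our own (not part of the patch, using only public ND4J calls): a detached copy of a view parameter does not see later in-place updates to the flattened parameter vector.

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.indexing.NDArrayIndex;

    public class ViewParamsDemo {
        public static void main(String[] args) {
            //DL4J keeps every layer parameter as a view into one flattened parameter vector
            INDArray flatParams = Nd4j.createFromArray(1f, 2f, 3f, 4f, 5f, 6f);
            INDArray w = flatParams.get(NDArrayIndex.interval(0, 4));   //view of the first 4 values
            INDArray detachedCopy = w.dup();                            //what a non-view holder would keep

            flatParams.addi(1.0f);                                      //in-place updater step on the flat vector

            System.out.println(w.getFloat(0));            //2.0 - the view sees the update
            System.out.println(detachedCopy.getFloat(0)); //1.0 - stale, hence re-assigning arrays each iteration
        }
    }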
    /**Returns the parameters of the neural network as a flattened row vector
@ -312,7 +356,8 @@ public class SameDiffOutputLayer extends AbstractLayer<org.deeplearning4j.nn.con
    @Override
    public double computeScore(double fullNetRegTerm, boolean training, LayerWorkspaceMgr workspaceMgr) {
-       return (activateHelper(false, workspaceMgr).getDouble(0) + fullNetRegTerm) / input.size(0);
+       INDArray scoreArr = activateHelper(false, workspaceMgr);
+       return (scoreArr.getDouble(0) + fullNetRegTerm) / input.size(0);
    }

    @Override
View File
@ -309,11 +309,11 @@ public abstract class BaseNDArray implements INDArray, Iterable {
     * @param ordering the ordering of the ndarray
     */
    public BaseNDArray(int[] shape, int[] stride, long offset, char ordering) {
-       this(Nd4j.createBuffer(ArrayUtil.prodLong(shape)), shape, stride, offset, ordering);
+       this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape)), shape, stride, offset, ordering);
    }

    public BaseNDArray(long[] shape, long[] stride, long offset, char ordering) {
-       this(Nd4j.createBuffer(ArrayUtil.prodLong(shape)), shape, stride, offset, ordering);
+       this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape)), shape, stride, offset, ordering);
    }

    /**
@ -326,19 +326,19 @@ public abstract class BaseNDArray implements INDArray, Iterable {
     * @param initialize Whether to initialize the INDArray. If true: initialize. If false: don't.
     */
    public BaseNDArray(int[] shape, int[] stride, long offset, char ordering, boolean initialize) {
-       this(Nd4j.createBuffer(ArrayUtil.prodLong(shape), initialize), shape, stride, offset, ordering);
+       this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), shape, stride, offset, ordering);
    }

    public BaseNDArray(long[] shape, long[] stride, long offset, char ordering, boolean initialize) {
-       this(Nd4j.createBuffer(ArrayUtil.prodLong(shape), initialize), shape, stride, offset, ordering);
+       this(Nd4j.createBuffer(shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), shape, stride, offset, ordering);
    }

    public BaseNDArray(DataType type, long[] shape, long[] stride, long offset, char ordering, boolean initialize) {
-       this(Nd4j.createBuffer(type, ArrayUtil.prodLong(shape), initialize), type, shape, stride, offset, ordering);
+       this(Nd4j.createBuffer(type, shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize), type, shape, stride, offset, ordering);
    }

    public BaseNDArray(DataType type, long[] shape, long[] stride, long offset, char ordering, boolean initialize, MemoryWorkspace workspace) {
-       this(Nd4j.createBuffer(type, ArrayUtil.prodLong(shape), initialize, workspace), type, shape, stride, offset, ordering);
+       this(Nd4j.createBuffer(type, shape.length == 0 ? 1 : ArrayUtil.prodLong(shape), initialize, workspace), type, shape, stride, offset, ordering);
    }
View File
@ -319,6 +319,11 @@ public class BasicWorkspaceTests extends BaseNd4jTest {
            long reqMemory = 5 * Nd4j.sizeOfDataType(array1.dataType());
            assertEquals(reqMemory + reqMemory % 8, wsI.getPrimaryOffset());
            assertEquals(array1, array2);

+           INDArray array3 = Nd4j.createUninitializedDetached(DataType.FLOAT, new long[0]);
+           assertTrue(array3.isScalar());
+           assertEquals(1, array3.length());
+           assertEquals(1, array3.data().length());
        }
    }
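For context, a short usage sketch of the now-working scalar path (our own example, not part of the patch; it reuses only the calls exercised by the test above plus INDArray.assign/addi). The detached scalar is not attached to any workspace, so it can be written to and kept beyond a workspace scope.

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class DetachedScalarDemo {
        public static void main(String[] args) {
            //Rank-0 array backed by a 1-element buffer, detached from any workspace
            INDArray score = Nd4j.createUninitializedDetached(DataType.FLOAT, new long[0]);
            score.assign(0.0);            //uninitialized memory must be written before it is read
            score.addi(1.5);              //e.g. accumulate a value computed elsewhere

            System.out.println(score.isScalar());   //true
            System.out.println(score.length());     //1
            System.out.println(score.getDouble(0)); //1.5
        }
    }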