SameDiff execution, TF and memory management overhaul (#10)

* SameDiff execution memory management improvements, round 1

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Round 2

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Round 3

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Clear node outputs closed array references; Slight change to OpValidation internals to not rely on cached op outputs

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Next steps

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Next step

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Cleanup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More polish

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Add WeakIdentityHashMap

Signed-off-by: AlexDBlack <blacka101@gmail.com>
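(Note on the commit above.) A weak identity hash map keys entries by reference identity (== / System.identityHashCode) rather than equals()/hashCode(), and holds keys through WeakReferences so an entry disappears once its key array is garbage collected - which lets arrays be tracked without the tracker itself keeping them alive. A minimal sketch of the idea in plain Java; the actual class added here may differ in name and API:

    import java.lang.ref.ReferenceQueue;
    import java.lang.ref.WeakReference;
    import java.util.HashMap;
    import java.util.Map;

    /** Sketch of a weak identity map: keys compared by ==, held only weakly. */
    public class WeakIdentityMapSketch<K, V> {
        private final Map<Key<K>, V> map = new HashMap<>();
        private final ReferenceQueue<K> queue = new ReferenceQueue<>();

        public V put(K key, V value) {
            expunge();
            return map.put(new Key<>(key, queue), value);
        }

        public V get(K key) {
            expunge();
            return map.get(new Key<>(key, null));
        }

        /** Drop entries whose keys have already been garbage collected. */
        private void expunge() {
            Object stale;
            while ((stale = queue.poll()) != null) {
                map.remove(stale);
            }
        }

        private static final class Key<K> extends WeakReference<K> {
            private final int hash;

            Key(K referent, ReferenceQueue<K> q) {
                super(referent, q);
                this.hash = System.identityHashCode(referent);   // identity hash, not referent.hashCode()
            }

            @Override public int hashCode() { return hash; }

            @Override public boolean equals(Object o) {
                if (this == o) return true;
                if (!(o instanceof Key)) return false;
                K mine = get();
                Object theirs = ((Key<?>) o).get();
                return mine != null && mine == theirs;           // reference identity, not equals()
            }
        }
    }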

* Session fixes for control ops and next steps

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* First steps for training session + in-line updating

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Next steps

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix losses and history during training

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* BiasAdd and other fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Don't use SDVariable.getArr() in TFGraphTestAllHelper (import tests)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Cleanup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* First steps for new dependency tracking approach

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Start integrating dependency tracking for memory management

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Non-control op dependency tracking works/passes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Switch/merge

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Next steps

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Cleanup and next steps

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix issue with dependency tracking for initial variables/constants

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Add check for aliases when determining if safe to close array

Signed-off-by: AlexDBlack <blacka101@gmail.com>
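(Note on the commit above.) Closing an off-heap INDArray that still shares its buffer with another live array - a view, or the same array registered under a second variable name - would invalidate that other array as well, so the session has to check for aliases first. The per-array half of that guard can be expressed with public INDArray methods; the graph-level alias bookkeeping is the part this commit adds and is not shown here:

    import org.nd4j.linalg.api.ndarray.INDArray;

    /** Sketch: release an array only when it is not a view and reports itself closeable. */
    public class ArrayCloseGuard {

        /** @return true if the array was actually closed */
        public static boolean closeIfSafe(INDArray arr) {
            if (arr == null || arr.isView()) {
                return false;   // views share their data buffer with a parent array
            }
            if (!arr.closeable()) {
                return false;   // already closed, workspace-attached, or otherwise unsafe to release
            }
            arr.close();
            return true;
        }
    }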

* First pass on new TF graph import class

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Import fixes, op fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Cleanup and fixes for new TF import mapper

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More cleanup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Partial implementation of new dependency tracker

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Next steps

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* AbstractDependencyTracker for shared code

Signed-off-by: AlexDBlack <blacka101@gmail.com>
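(Note on the two commits above.) The dependency-tracking idea is straightforward bookkeeping: for each value (op output array, variable, etc.) record which downstream consumers still need it, and only mark it releasable once every recorded dependency has been satisfied. A stripped-down generic sketch of that core; the actual tracker classes in this PR handle considerably more (e.g. "any of" style dependencies and reverse indices for efficiency):

    import java.util.HashMap;
    import java.util.HashSet;
    import java.util.Map;
    import java.util.Set;

    /** Minimal dependency tracker: a value becomes releasable once all its dependencies are satisfied. */
    public class SimpleDependencyTracker<V, D> {
        private final Map<V, Set<D>> outstanding = new HashMap<>();
        private final Set<V> releasable = new HashSet<>();

        /** Value 'v' must be kept until dependency 'd' has been satisfied. */
        public void addDependency(V v, D d) {
            outstanding.computeIfAbsent(v, k -> new HashSet<>()).add(d);
            releasable.remove(v);
        }

        /** Dependency 'd' is now satisfied (e.g. the consuming op has executed). */
        public void markSatisfied(D d) {
            for (Map.Entry<V, Set<D>> e : outstanding.entrySet()) {
                if (e.getValue().remove(d) && e.getValue().isEmpty()) {
                    releasable.add(e.getKey());   // nothing still depends on this value
                }
            }
        }

        /** Values that are now safe to free/close. */
        public Set<V> getReleasable() {
            return new HashSet<>(releasable);
        }
    }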

* Overhaul SameDiff graph execution (dependency tracking)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More fixes, cleanup, next steps

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Add no-op memory manager, cleanup, fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>
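(Note on the commit above.) The no-op manager is presumably the simplest implementation of the session's array-memory abstraction: allocate normally and make release() do nothing, leaving arrays to the garbage collector - useful as a correctness baseline when debugging the close-based managers. A hedged sketch against a hypothetical interface; the real SameDiff interface and its method names may differ:

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    /** Hypothetical session memory-manager abstraction. */
    interface SessionArrayManager {
        INDArray allocate(DataType dataType, long... shape);
        void release(INDArray array);
    }

    /** No-op manager: plain allocation, release() does nothing, GC reclaims memory eventually. */
    public class NoOpArrayManager implements SessionArrayManager {
        @Override
        public INDArray allocate(DataType dataType, long... shape) {
            return Nd4j.createUninitialized(dataType, shape);
        }

        @Override
        public void release(INDArray array) {
            // Intentionally empty - used as a baseline / for debugging the close-based managers
        }
    }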

* Fix switch dependency tracking

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* INDArray.toString: no exception on closed arrays, just note closed

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix enter and exit dependency tracking

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* TensorArray memory management fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Add unique ID for INDArray instances

Signed-off-by: AlexDBlack <blacka101@gmail.com>
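(Note on the commit above.) Once arrays can be closed and their buffers reused, object identity is no longer enough to distinguish "the same array" from "a new array reusing the same memory", hence a per-instance ID assigned at construction. The standard way to implement such a tag; field and getter names here are illustrative, not the actual INDArray API:

    import java.util.concurrent.atomic.AtomicLong;

    /** Sketch: process-wide counter sampled once per instance. */
    public class TrackedArray {
        private static final AtomicLong ID_COUNTER = new AtomicLong(0);

        private final long id = ID_COUNTER.getAndIncrement();

        /** Unique within the process and never reused, even if this object's memory is recycled. */
        public long getId() {
            return id;
        }
    }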

* Fix memory management for NextIteration outputs in multi-iteration loops

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Remove (now unnecessary) special case handling for nested enters

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Cleanup

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Handle control dependencies during execution; javadoc for memory managers

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Cleanup, polish, code comments, javadoc

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Cleanup and more javadoc

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Add memory validation for all TF import tests - ensure all arrays (except outputs) are released

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Clean up arrays waiting on unexecuted ops at the end of execution

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fixes for enter op memory management in the context of multiple non-nested loops/frames

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix order of operation issues for dependency tracker

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Always clear op fields after execution to avoid leaks or unintended array reuse

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Re-implement dtype conversion

Signed-off-by: AlexDBlack <blacka101@gmail.com>
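(Note on the commit above.) Dtype conversion here means changing the data types of the variables/arrays held by a graph (e.g. importing a FLOAT graph and running it as DOUBLE). The ND4J primitive it ultimately rests on is INDArray.castTo, illustrated below; the SameDiff-level conversion this commit reworks is not shown:

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class CastExample {
        public static void main(String[] args) {
            INDArray f = Nd4j.create(DataType.FLOAT, 2, 3);      // zero-initialized FLOAT array
            INDArray d = f.castTo(DataType.DOUBLE);              // new DOUBLE array with the same values

            System.out.println(f.dataType() + " -> " + d.dataType());   // FLOAT -> DOUBLE
        }
    }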

* Fix for control dependencies execution (dependency tracking)

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix TF import overrides and filtering

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix for constant enter array dependency tracking

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* DL4J Fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More DL4J fixes

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Cleanup and polish

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More polish and javadoc

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* More logging level tweaks, small DL4J fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small fix to DL4J SameDiffLayer

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix empty array deserialization, add extra deserialization checks

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* FlatBuffers control dep serialization fixes; test serialization as part of all TF import tests

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Variable control dependencies serialization fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Fix issue with removing inputs for ops

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* FlatBuffers NDArray deserialization fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* FlatBuffers NDArray deserialization fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Small fix

Signed-off-by: AlexDBlack <blacka101@gmail.com>

* Final cleanup/polish

Signed-off-by: AlexDBlack <blacka101@gmail.com>
Branch: master
Alex Black 2019-10-23 21:19:50 +11:00 (committed by GitHub)
parent f31661e13b · commit 3f0b4a2d4c
154 changed files with 5185 additions and 6611 deletions

View File

@ -157,8 +157,8 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
     @Test
     public void testMultiLayerNetworkFrozenLayerParamsAfterBackprop() {
-        DataSet randomData = new DataSet(Nd4j.rand(100, 4, 12345), Nd4j.rand(100, 1, 12345));
+        Nd4j.getRandom().setSeed(12345);
+        DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
         MultiLayerConfiguration conf1 = new NeuralNetConfiguration.Builder()
                 .seed(12345)
@ -194,8 +194,9 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
     @Test
     public void testComputationGraphFrozenLayerParamsAfterBackprop() {
-        DataSet randomData = new DataSet(Nd4j.rand(100, 4,12345), Nd4j.rand(100, 1, 12345));
+        Nd4j.getRandom().setSeed(12345);
+        DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
         String frozenBranchName = "B1-";
         String unfrozenBranchName = "B2-";
@ -254,43 +255,18 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
*/ */
@Test @Test
public void testFrozenLayerVsSgd() { public void testFrozenLayerVsSgd() {
DataSet randomData = new DataSet(Nd4j.rand(100, 4, 12345), Nd4j.rand(100, 1, 12345)); Nd4j.getRandom().setSeed(12345);
DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
MultiLayerConfiguration confSgd = new NeuralNetConfiguration.Builder() MultiLayerConfiguration confSgd = new NeuralNetConfiguration.Builder()
.seed(12345) .seed(12345)
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.updater(new Sgd(2)) .updater(new Sgd(2))
.list() .list()
.layer(0, .layer(0,new DenseLayer.Builder().nIn(4).nOut(3).build())
new DenseLayer.Builder() .layer(1,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build())
.nIn(4) .layer(2,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build())
.nOut(3) .layer(3,new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(2).nOut(1).build())
.build()
)
.layer(1,
new DenseLayer.Builder()
.updater(new Sgd(0.0))
.biasUpdater(new Sgd(0.0))
.nIn(3)
.nOut(4)
.build()
).layer(2,
new DenseLayer.Builder()
.updater(new Sgd(0.0))
.biasUpdater(new Sgd(0.0))
.nIn(4)
.nOut(2)
.build()
).layer(3,
new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
.updater(new Sgd(0.0))
.biasUpdater(new Sgd(0.0))
.activation(Activation.TANH)
.nIn(2)
.nOut(1)
.build()
)
.build(); .build();
MultiLayerConfiguration confFrozen = new NeuralNetConfiguration.Builder() MultiLayerConfiguration confFrozen = new NeuralNetConfiguration.Builder()
@ -298,36 +274,10 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
.weightInit(WeightInit.XAVIER) .weightInit(WeightInit.XAVIER)
.updater(new Sgd(2)) .updater(new Sgd(2))
.list() .list()
.layer(0, .layer(0,new DenseLayer.Builder().nIn(4).nOut(3).build())
new DenseLayer.Builder() .layer(1,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(3).nOut(4).build()))
.nIn(4) .layer(2,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new DenseLayer.Builder().nIn(4).nOut(2).build()))
.nOut(3) .layer(3,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(2).nOut(1).build()))
.build()
)
.layer(1,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder()
.nIn(3)
.nOut(4)
.build()
)
)
.layer(2,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder()
.nIn(4)
.nOut(2)
.build()
)
).layer(3,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
.activation(Activation.TANH)
.nIn(2)
.nOut(1)
.build()
)
)
.build(); .build();
MultiLayerNetwork frozenNetwork = new MultiLayerNetwork(confFrozen); MultiLayerNetwork frozenNetwork = new MultiLayerNetwork(confFrozen);
frozenNetwork.init(); frozenNetwork.init();
@ -359,8 +309,8 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
     @Test
     public void testComputationGraphVsSgd() {
-        DataSet randomData = new DataSet(Nd4j.rand(100, 4, 12345), Nd4j.rand(100, 1, 12345));
+        Nd4j.getRandom().setSeed(12345);
+        DataSet randomData = new DataSet(Nd4j.rand(100, 4), Nd4j.rand(100, 1));
         String frozenBranchName = "B1-";
         String unfrozenBranchName = "B2-";
@ -381,71 +331,19 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
.seed(12345) .seed(12345)
.graphBuilder() .graphBuilder()
.addInputs("input") .addInputs("input")
.addLayer(initialLayer, .addLayer(initialLayer,new DenseLayer.Builder().nIn(4).nOut(4).build(),"input")
new DenseLayer.Builder() .addLayer(frozenBranchUnfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(3).build(), initialLayer)
.nIn(4) .addLayer(frozenBranchFrozenLayer1,new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
.nOut(4) new DenseLayer.Builder().nIn(3).nOut(4).build()),frozenBranchUnfrozenLayer0)
.build(),
"input"
)
.addLayer(frozenBranchUnfrozenLayer0,
new DenseLayer.Builder()
.nIn(4)
.nOut(3)
.build(),
initialLayer
)
.addLayer(frozenBranchFrozenLayer1,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder()
.nIn(3)
.nOut(4)
.build()
),
frozenBranchUnfrozenLayer0
)
.addLayer(frozenBranchFrozenLayer2, .addLayer(frozenBranchFrozenLayer2,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop( new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new DenseLayer.Builder() new DenseLayer.Builder().nIn(4).nOut(2).build()),frozenBranchFrozenLayer1)
.nIn(4) .addLayer(unfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer)
.nOut(2) .addLayer(unfrozenLayer1,new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0)
.build() .addLayer(unfrozenBranch2,new DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1)
), .addVertex("merge",new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
frozenBranchFrozenLayer1 .addLayer(frozenBranchOutput, new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
) new OutputLayer.Builder(LossFunctions.LossFunction.MSE).activation(Activation.TANH).nIn(3).nOut(1).build()),"merge")
.addLayer(unfrozenLayer0,
new DenseLayer.Builder()
.nIn(4)
.nOut(4)
.build(),
initialLayer
)
.addLayer(unfrozenLayer1,
new DenseLayer.Builder()
.nIn(4)
.nOut(2)
.build(),
unfrozenLayer0
)
.addLayer(unfrozenBranch2,
new DenseLayer.Builder()
.nIn(2)
.nOut(1)
.build(),
unfrozenLayer1
)
.addVertex("merge",
new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
.addLayer(frozenBranchOutput,
new org.deeplearning4j.nn.conf.layers.misc.FrozenLayerWithBackprop(
new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
.activation(Activation.TANH)
.nIn(3)
.nOut(1)
.build()
),
"merge"
)
.setOutputs(frozenBranchOutput) .setOutputs(frozenBranchOutput)
.build(); .build();
@ -454,73 +352,15 @@ public class FrozenLayerWithBackpropTest extends BaseDL4JTest {
.seed(12345) .seed(12345)
.graphBuilder() .graphBuilder()
.addInputs("input") .addInputs("input")
.addLayer(initialLayer, .addLayer(initialLayer, new DenseLayer.Builder().nIn(4).nOut(4).build(),"input")
new DenseLayer.Builder() .addLayer(frozenBranchUnfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(3).build(),initialLayer)
.nIn(4) .addLayer(frozenBranchFrozenLayer1,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(3).nOut(4).build(),frozenBranchUnfrozenLayer0)
.nOut(4) .addLayer(frozenBranchFrozenLayer2,new DenseLayer.Builder().updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).nIn(4).nOut(2).build(),frozenBranchFrozenLayer1)
.build(), .addLayer(unfrozenLayer0,new DenseLayer.Builder().nIn(4).nOut(4).build(),initialLayer)
"input" .addLayer(unfrozenLayer1,new DenseLayer.Builder().nIn(4).nOut(2).build(),unfrozenLayer0)
) .addLayer(unfrozenBranch2,new DenseLayer.Builder().nIn(2).nOut(1).build(),unfrozenLayer1)
.addLayer(frozenBranchUnfrozenLayer0, .addVertex("merge",new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
new DenseLayer.Builder() .addLayer(frozenBranchOutput,new OutputLayer.Builder(LossFunctions.LossFunction.MSE).updater(new Sgd(0.0)).biasUpdater(new Sgd(0.0)).activation(Activation.TANH).nIn(3).nOut(1).build(),"merge")
.nIn(4)
.nOut(3)
.build(),
initialLayer
)
.addLayer(frozenBranchFrozenLayer1,
new DenseLayer.Builder()
.updater(new Sgd(0.0))
.biasUpdater(new Sgd(0.0))
.nIn(3)
.nOut(4)
.build(),
frozenBranchUnfrozenLayer0
)
.addLayer(frozenBranchFrozenLayer2,
new DenseLayer.Builder()
.updater(new Sgd(0.0))
.biasUpdater(new Sgd(0.0))
.nIn(4)
.nOut(2)
.build()
,
frozenBranchFrozenLayer1
)
.addLayer(unfrozenLayer0,
new DenseLayer.Builder()
.nIn(4)
.nOut(4)
.build(),
initialLayer
)
.addLayer(unfrozenLayer1,
new DenseLayer.Builder()
.nIn(4)
.nOut(2)
.build(),
unfrozenLayer0
)
.addLayer(unfrozenBranch2,
new DenseLayer.Builder()
.nIn(2)
.nOut(1)
.build(),
unfrozenLayer1
)
.addVertex("merge",
new MergeVertex(), frozenBranchFrozenLayer2, unfrozenBranch2)
.addLayer(frozenBranchOutput,
new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
.updater(new Sgd(0.0))
.biasUpdater(new Sgd(0.0))
.activation(Activation.TANH)
.nIn(3)
.nOut(1)
.build()
,
"merge"
)
.setOutputs(frozenBranchOutput) .setOutputs(frozenBranchOutput)
.build(); .build();

View File

@ -172,8 +172,8 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
         Map<String,INDArray> placeholders = new HashMap<>();
         placeholders.put("input", f);
         placeholders.put("label", l);
-        sd.exec(placeholders, lossMse.getVarName());
-        INDArray outSd = a1.getArr();
+        Map<String,INDArray> map = sd.output(placeholders, lossMse.getVarName(), a1.getVarName());
+        INDArray outSd = map.get(a1.getVarName());
         INDArray outDl4j = net.output(f);
         assertEquals(testName, outDl4j, outSd);
@ -187,7 +187,7 @@ public class CompareTrainingImplementations extends BaseDL4JTest {
         //Check score
         double scoreDl4j = net.score();
-        double scoreSd = lossMse.getArr().getDouble(0) + sd.calcRegularizationScore();
+        double scoreSd = map.get(lossMse.getVarName()).getDouble(0) + sd.calcRegularizationScore();
         assertEquals(testName, scoreDl4j, scoreSd, 1e-6);
         double lossRegScoreSD = sd.calcRegularizationScore();
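For context on the change above: with the new memory management, arrays cached on SDVariable (getArr()) may already have been released after execution, so callers now ask for the variables they need and get them back in a map from sd.output(...). A minimal usage sketch (graph and names are illustrative):

    import java.util.Collections;
    import java.util.Map;

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class OutputApiExample {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
            SDVariable w = sd.var("w", Nd4j.rand(4, 3));
            SDVariable out = in.mmul("out", w);

            Map<String, INDArray> placeholders =
                    Collections.singletonMap("input", Nd4j.rand(2, 4));

            // Request exactly the outputs needed; everything else can be released eagerly
            Map<String, INDArray> results = sd.output(placeholders, "out");
            INDArray mmulOut = results.get("out");
            System.out.println(java.util.Arrays.toString(mmulOut.shape()));   // [2, 3]
        }
    }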

View File

@ -145,7 +145,7 @@ public class LocallyConnected1D extends SameDiffLayer {
         val weightsShape = new long[] {outputSize, featureDim, nOut};
         params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape);
         if (hasBias) {
-            val biasShape = new long[] {1, nOut};
+            val biasShape = new long[] {nOut};
             params.addBiasParam(ConvolutionParamInitializer.BIAS_KEY, biasShape);
         }
     }
@ -200,7 +200,7 @@ public class LocallyConnected1D extends SameDiffLayer {
         if (hasBias) {
             SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY);
-            SDVariable biasAddedResult = sameDiff.nn().biasAdd(result, b);
+            SDVariable biasAddedResult = sameDiff.nn().biasAdd(result, b, true);
             return activation.asSameDiff("out", sameDiff, biasAddedResult);
         } else {
             return activation.asSameDiff("out", sameDiff, result);
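On the biasAdd change above: the bias parameter becomes rank 1 ([nOut] instead of [1, nOut]) and the op gains a third boolean argument, which appears to be the data-format (channels-first) flag for 4D inputs; for the rank-2 activations used by these layers it has no effect. Sketch of the call shape, with that assumption flagged in the comment:

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;

    public class BiasAddSketch {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable act = sd.placeHolder("act", DataType.FLOAT, -1, 8);   // [minibatch, nOut]
            SDVariable b = sd.var("b", DataType.FLOAT, 8);                   // rank-1 bias: [nOut]

            // Third argument: assumed to be the channels-first (NCHW) flag for 4D inputs,
            // matching the 'true' passed in the LocallyConnected layers above
            SDVariable withBias = sd.nn().biasAdd(act, b, true);
        }
    }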

View File

@ -145,7 +145,7 @@ public class LocallyConnected2D extends SameDiffLayer {
         val weightsShape = new long[] {outputSize[0] * outputSize[1], featureDim, nOut};
         params.addWeightParam(ConvolutionParamInitializer.WEIGHT_KEY, weightsShape);
         if (hasBias) {
-            val biasShape = new long[] {1, nOut};
+            val biasShape = new long[] {nOut};
             params.addBiasParam(ConvolutionParamInitializer.BIAS_KEY, biasShape);
         }
     }
@ -211,7 +211,7 @@ public class LocallyConnected2D extends SameDiffLayer {
         if (hasBias) {
             SDVariable b = paramTable.get(ConvolutionParamInitializer.BIAS_KEY);
-            SDVariable biasAddedResult = sameDiff.nn().biasAdd(permutedResult, b);
+            SDVariable biasAddedResult = sameDiff.nn().biasAdd(permutedResult, b, true);
             return activation.asSameDiff("out", sameDiff, biasAddedResult);
         } else {
             return activation.asSameDiff("out", sameDiff, permutedResult);

View File

@ -114,7 +114,7 @@ public class MergeVertex extends BaseGraphVertex {
         }
         try(MemoryWorkspace ws = workspaceMgr.notifyScopeBorrowed(ArrayType.ACTIVATIONS)){
-            return Nd4j.hstack(in);
+            return Nd4j.concat(1, in);
         }
     }
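For reference, Nd4j.concat takes the concatenation dimension explicitly, so the call above concatenates the input activations along dimension 1 (the feature/channel dimension), which is what hstack did for the rank-2 case. A small illustration:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class ConcatExample {
        public static void main(String[] args) {
            INDArray a = Nd4j.zeros(3, 2);
            INDArray b = Nd4j.ones(3, 4);

            // Concatenate along dimension 1: shapes [3,2] + [3,4] -> [3,6]
            INDArray merged = Nd4j.concat(1, a, b);
            System.out.println(java.util.Arrays.toString(merged.shape()));   // [3, 6]
        }
    }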

View File

@ -134,6 +134,7 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
         Gradient g = new DefaultGradient();
         INDArray[] dLdIns;
+        boolean[] noClose = new boolean[getNumInputArrays()];
         try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
             if(sameDiff == null){
                 doInit();
@ -167,20 +168,21 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
             //Because DL4J parameters are views, and SameDiff uses DeviceLocal (which doesn't support views), we need to update the arrays on each iteration
             //TODO Find a more efficient solution for this
+            List<String> required = new ArrayList<>(inputNames.size());        //Ensure that the input placeholder gradients are calculated
             for (Map.Entry<String, INDArray> e : paramTable.entrySet()) {
                 INDArray arr = e.getValue();
                 sameDiff.assignArray(arr, sameDiff.getVariable(e.getKey()));
             }
-            List<String> required = new ArrayList<>(inputNames.size());        //Ensure that the input placeholder gradients are calculated
-            for(String s : inputNames){
-                required.add(sameDiff.getVariable(s).gradient().getVarName());
-            }
-            sameDiff.execBackwards(phMap, required);
+            required.addAll(paramTable.keySet());
+            required.addAll(inputNames);
+            Map<String,INDArray> gradsMap = sameDiff.calculateGradients(phMap, required);
             for(String s : paramTable.keySet() ){
-                INDArray sdGrad = sameDiff.grad(s).getArr();
+                INDArray sdGrad = gradsMap.get(s);
                 INDArray dl4jGrad = gradTable.get(s);
                 dl4jGrad.assign(sdGrad);                                        //TODO OPTIMIZE THIS
+                sdGrad.close();                                                 //TODO optimize this
                 g.gradientForVariable().put(s, dl4jGrad);
             }
@ -195,13 +197,18 @@ public class SameDiffGraphVertex extends BaseGraphVertex {
                     //Edge case with lambda vertices like identity: SameDiff doesn't store the placeholders
                     // So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here
                     dLdIns[j] = epsilon;
+                    noClose[j] = true;
                 }
             }
         }
         //TODO optimize
         for( int i=0; i<dLdIns.length; i++ ){
+            INDArray before = dLdIns[i];
             dLdIns[i] = workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIns[i]);
+            if(!noClose[i]){
+                before.close();
+            }
         }
         //Clear placeholders and op inputs to ensure no out-of-scope arrays are still referenced anywhere

View File

@ -110,7 +110,13 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
             sameDiff.clearPlaceholders(true);
             sameDiff.clearOpInputs();
-            return workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
+            INDArray ret = workspaceMgr.dup(ArrayType.ACTIVATIONS, result);
+            if(!result.isAttached() && result.closeable()) {
+                //May be attached in rare edge case - for identity, or if gradients are passed through from output to input
+                // unchaned, as in identity, add scalar, etc
+                result.close();
+            }
+            return ret;
         }
     }
@ -122,6 +128,7 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
         Gradient g = new DefaultGradient();
         INDArray dLdIn;
+        boolean noCloseEps = false;
         try(MemoryWorkspace ws = Nd4j.getWorkspaceManager().scopeOutOfWorkspaces()){
             if(sameDiff == null){
                 doInit();
@ -151,26 +158,25 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
             }
             List<String> requiredGrads = new ArrayList<>(paramTable.size() + 1);
-            requiredGrads.add(sameDiff.grad(INPUT_KEY).getVarName());
-            for(String s : paramTable.keySet()){
-                requiredGrads.add(sameDiff.grad(s).getVarName());
-            }
-            sameDiff.execBackwards(phMap, requiredGrads);
+            requiredGrads.add(INPUT_KEY);
+            requiredGrads.addAll(paramTable.keySet());
+            Map<String,INDArray> m = sameDiff.calculateGradients(phMap, requiredGrads);
             for(String s : paramTable.keySet() ){
-                INDArray sdGrad = sameDiff.grad(s).getArr();
+                INDArray sdGrad = m.get(s);
                 INDArray dl4jGrad = gradTable.get(s);
                 dl4jGrad.assign(sdGrad);                                        //TODO OPTIMIZE THIS
                 g.gradientForVariable().put(s, dl4jGrad);
+                sdGrad.close();
             }
-            SDVariable v = sameDiff.grad(INPUT_KEY);
-            dLdIn = v.getArr();
-            if(dLdIn == null && fn.getGradPlaceholderName().equals(v.getVarName())){
+            dLdIn = m.get(INPUT_KEY);
+            if(dLdIn == null && fn.getGradPlaceholderName().equals(INPUT_KEY)){
                 //Edge case with lambda layers like identity: SameDiff doesn't store the placeholders
                 // So, this getArr() can be trying to get placeholder from SameDiff instance, when it's available here
                 dLdIn = epsilon;
+                noCloseEps = true;
             }
         }
@ -178,7 +184,12 @@ public class SameDiffLayer extends AbstractLayer<AbstractSameDiffLayer> {
         sameDiff.clearPlaceholders(true);
         sameDiff.clearOpInputs();
-        return new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn));   //TODO OPTIMIZE THIS
+        Pair<Gradient, INDArray> ret = new Pair<>(g, workspaceMgr.dup(ArrayType.ACTIVATION_GRAD, dLdIn));   //TODO OPTIMIZE THIS
+        if(!noCloseEps && !dLdIn.isAttached() && dLdIn.closeable()) {
+            //Edge case: identity etc - might just pass gradient array through unchanged
+            dLdIn.close();
+        }
+        return ret;
     }
     /**Returns the parameters of the neural network as a flattened row vector
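Both layer classes above move from execBackwards(...) plus SDVariable.grad(name).getArr() to the new calculateGradients call, which takes the names of the forward variables whose gradients are wanted and returns the gradient arrays in a map; the caller then owns those arrays (hence the explicit close() after copying into the DL4J gradient views). A minimal standalone usage sketch (graph and names are illustrative):

    import java.util.Arrays;
    import java.util.Collections;
    import java.util.Map;

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    public class CalculateGradientsExample {
        public static void main(String[] args) {
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.placeHolder("input", DataType.FLOAT, -1, 4);
            SDVariable w = sd.var("w", Nd4j.rand(4, 1));
            SDVariable loss = in.mmul(w).sum("loss");
            sd.setLossVariables("loss");

            Map<String, INDArray> placeholders =
                    Collections.singletonMap("input", Nd4j.rand(3, 4));

            // Gradients are requested by the names of the *forward* variables
            Map<String, INDArray> grads = sd.calculateGradients(placeholders, Arrays.asList("w", "input"));
            INDArray dLdW = grads.get("w");
            System.out.println(Arrays.toString(dLdW.shape()));   // [4, 1]
        }
    }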

View File

@ -106,6 +106,12 @@ public struct FlatNode : IFlatbufferObject
#endif #endif
public DType[] GetOutputTypesArray() { return __p.__vector_as_array<DType>(38); } public DType[] GetOutputTypesArray() { return __p.__vector_as_array<DType>(38); }
public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } } public FlatArray? Scalar { get { int o = __p.__offset(40); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } }
public string ControlDeps(int j) { int o = __p.__offset(42); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
public int ControlDepsLength { get { int o = __p.__offset(42); return o != 0 ? __p.__vector_len(o) : 0; } }
public string VarControlDeps(int j) { int o = __p.__offset(44); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
public int VarControlDepsLength { get { int o = __p.__offset(44); return o != 0 ? __p.__vector_len(o) : 0; } }
public string ControlDepFor(int j) { int o = __p.__offset(46); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
public int ControlDepForLength { get { int o = __p.__offset(46); return o != 0 ? __p.__vector_len(o) : 0; } }
public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder, public static Offset<FlatNode> CreateFlatNode(FlatBufferBuilder builder,
int id = 0, int id = 0,
@ -126,9 +132,15 @@ public struct FlatNode : IFlatbufferObject
VectorOffset outputNamesOffset = default(VectorOffset), VectorOffset outputNamesOffset = default(VectorOffset),
StringOffset opNameOffset = default(StringOffset), StringOffset opNameOffset = default(StringOffset),
VectorOffset outputTypesOffset = default(VectorOffset), VectorOffset outputTypesOffset = default(VectorOffset),
Offset<FlatArray> scalarOffset = default(Offset<FlatArray>)) { Offset<FlatArray> scalarOffset = default(Offset<FlatArray>),
builder.StartObject(19); VectorOffset controlDepsOffset = default(VectorOffset),
VectorOffset varControlDepsOffset = default(VectorOffset),
VectorOffset controlDepForOffset = default(VectorOffset)) {
builder.StartObject(22);
FlatNode.AddOpNum(builder, opNum); FlatNode.AddOpNum(builder, opNum);
FlatNode.AddControlDepFor(builder, controlDepForOffset);
FlatNode.AddVarControlDeps(builder, varControlDepsOffset);
FlatNode.AddControlDeps(builder, controlDepsOffset);
FlatNode.AddScalar(builder, scalarOffset); FlatNode.AddScalar(builder, scalarOffset);
FlatNode.AddOutputTypes(builder, outputTypesOffset); FlatNode.AddOutputTypes(builder, outputTypesOffset);
FlatNode.AddOpName(builder, opNameOffset); FlatNode.AddOpName(builder, opNameOffset);
@ -150,7 +162,7 @@ public struct FlatNode : IFlatbufferObject
return FlatNode.EndFlatNode(builder); return FlatNode.EndFlatNode(builder);
} }
public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(19); } public static void StartFlatNode(FlatBufferBuilder builder) { builder.StartObject(22); }
public static void AddId(FlatBufferBuilder builder, int id) { builder.AddInt(0, id, 0); } public static void AddId(FlatBufferBuilder builder, int id) { builder.AddInt(0, id, 0); }
public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); } public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
public static void AddOpType(FlatBufferBuilder builder, OpType opType) { builder.AddSbyte(2, (sbyte)opType, 0); } public static void AddOpType(FlatBufferBuilder builder, OpType opType) { builder.AddSbyte(2, (sbyte)opType, 0); }
@ -200,6 +212,18 @@ public struct FlatNode : IFlatbufferObject
public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); } public static VectorOffset CreateOutputTypesVectorBlock(FlatBufferBuilder builder, DType[] data) { builder.StartVector(1, data.Length, 1); builder.Add(data); return builder.EndVector(); }
public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); } public static void StartOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(1, numElems, 1); }
public static void AddScalar(FlatBufferBuilder builder, Offset<FlatArray> scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); } public static void AddScalar(FlatBufferBuilder builder, Offset<FlatArray> scalarOffset) { builder.AddOffset(18, scalarOffset.Value, 0); }
public static void AddControlDeps(FlatBufferBuilder builder, VectorOffset controlDepsOffset) { builder.AddOffset(19, controlDepsOffset.Value, 0); }
public static VectorOffset CreateControlDepsVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
public static VectorOffset CreateControlDepsVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
public static void StartControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
public static void AddVarControlDeps(FlatBufferBuilder builder, VectorOffset varControlDepsOffset) { builder.AddOffset(20, varControlDepsOffset.Value, 0); }
public static VectorOffset CreateVarControlDepsVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
public static VectorOffset CreateVarControlDepsVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
public static void StartVarControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
public static void AddControlDepFor(FlatBufferBuilder builder, VectorOffset controlDepForOffset) { builder.AddOffset(21, controlDepForOffset.Value, 0); }
public static VectorOffset CreateControlDepForVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
public static VectorOffset CreateControlDepForVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
public static void StartControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) { public static Offset<FlatNode> EndFlatNode(FlatBufferBuilder builder) {
int o = builder.EndObject(); int o = builder.EndObject();
return new Offset<FlatNode>(o); return new Offset<FlatNode>(o);

View File

@ -66,6 +66,12 @@ public final class FlatNode extends Table {
public ByteBuffer outputTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 38, 1); } public ByteBuffer outputTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 38, 1); }
public FlatArray scalar() { return scalar(new FlatArray()); } public FlatArray scalar() { return scalar(new FlatArray()); }
public FlatArray scalar(FlatArray obj) { int o = __offset(40); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; } public FlatArray scalar(FlatArray obj) { int o = __offset(40); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; }
public String controlDeps(int j) { int o = __offset(42); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int controlDepsLength() { int o = __offset(42); return o != 0 ? __vector_len(o) : 0; }
public String varControlDeps(int j) { int o = __offset(44); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; }
public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; }
public static int createFlatNode(FlatBufferBuilder builder, public static int createFlatNode(FlatBufferBuilder builder,
int id, int id,
@ -86,9 +92,15 @@ public final class FlatNode extends Table {
int outputNamesOffset, int outputNamesOffset,
int opNameOffset, int opNameOffset,
int outputTypesOffset, int outputTypesOffset,
int scalarOffset) { int scalarOffset,
builder.startObject(19); int controlDepsOffset,
int varControlDepsOffset,
int controlDepForOffset) {
builder.startObject(22);
FlatNode.addOpNum(builder, opNum); FlatNode.addOpNum(builder, opNum);
FlatNode.addControlDepFor(builder, controlDepForOffset);
FlatNode.addVarControlDeps(builder, varControlDepsOffset);
FlatNode.addControlDeps(builder, controlDepsOffset);
FlatNode.addScalar(builder, scalarOffset); FlatNode.addScalar(builder, scalarOffset);
FlatNode.addOutputTypes(builder, outputTypesOffset); FlatNode.addOutputTypes(builder, outputTypesOffset);
FlatNode.addOpName(builder, opNameOffset); FlatNode.addOpName(builder, opNameOffset);
@ -110,7 +122,7 @@ public final class FlatNode extends Table {
return FlatNode.endFlatNode(builder); return FlatNode.endFlatNode(builder);
} }
public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(19); } public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); }
public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); } public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); }
public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); } public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); } public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); }
@ -150,6 +162,15 @@ public final class FlatNode extends Table {
public static int createOutputTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); } public static int createOutputTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); }
public static void startOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); } public static void startOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); }
public static void addScalar(FlatBufferBuilder builder, int scalarOffset) { builder.addOffset(18, scalarOffset, 0); } public static void addScalar(FlatBufferBuilder builder, int scalarOffset) { builder.addOffset(18, scalarOffset, 0); }
public static void addControlDeps(FlatBufferBuilder builder, int controlDepsOffset) { builder.addOffset(19, controlDepsOffset, 0); }
public static int createControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static void addVarControlDeps(FlatBufferBuilder builder, int varControlDepsOffset) { builder.addOffset(20, varControlDepsOffset, 0); }
public static int createVarControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startVarControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); }
public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static int endFlatNode(FlatBufferBuilder builder) { public static int endFlatNode(FlatBufferBuilder builder) {
int o = builder.endObject(); int o = builder.endObject();
return o; return o;
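The three vectors added to FlatNode above (and the matching fields on FlatVariable further down) carry the control-dependency information that previously did not survive FlatBuffers round-trips. Using only the generated methods shown in this diff, writing a node's control-dependency names looks roughly as follows; the op name "someOp" is illustrative and the full serializer obviously populates many more fields:

    import com.google.flatbuffers.FlatBufferBuilder;
    import org.nd4j.graph.FlatNode;

    public class ControlDepSerializationSketch {
        public static int writeNode(FlatBufferBuilder builder, int id, String name) {
            // String and vector offsets must be created before startFlatNode(...)
            int nameOffset = builder.createString(name);
            int[] deps = new int[] { builder.createString("someOp") };          // illustrative op name
            int controlDepsOffset = FlatNode.createControlDepsVector(builder, deps);

            FlatNode.startFlatNode(builder);
            FlatNode.addId(builder, id);
            FlatNode.addName(builder, nameOffset);
            FlatNode.addControlDeps(builder, controlDepsOffset);                // new field in this PR
            // ...opType, opNum, inputs/outputs, varControlDeps, controlDepFor, etc. would go here...
            return FlatNode.endFlatNode(builder);
        }
    }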

View File

@ -294,7 +294,52 @@ class FlatNode(object):
return obj return obj
return None return None
def FlatNodeStart(builder): builder.StartObject(19) # FlatNode
def ControlDeps(self, j):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(42))
if o != 0:
a = self._tab.Vector(o)
return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
return ""
# FlatNode
def ControlDepsLength(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(42))
if o != 0:
return self._tab.VectorLen(o)
return 0
# FlatNode
def VarControlDeps(self, j):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(44))
if o != 0:
a = self._tab.Vector(o)
return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
return ""
# FlatNode
def VarControlDepsLength(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(44))
if o != 0:
return self._tab.VectorLen(o)
return 0
# FlatNode
def ControlDepFor(self, j):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(46))
if o != 0:
a = self._tab.Vector(o)
return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
return ""
# FlatNode
def ControlDepForLength(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(46))
if o != 0:
return self._tab.VectorLen(o)
return 0
def FlatNodeStart(builder): builder.StartObject(22)
def FlatNodeAddId(builder, id): builder.PrependInt32Slot(0, id, 0) def FlatNodeAddId(builder, id): builder.PrependInt32Slot(0, id, 0)
def FlatNodeAddName(builder, name): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0) def FlatNodeAddName(builder, name): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
def FlatNodeAddOpType(builder, opType): builder.PrependInt8Slot(2, opType, 0) def FlatNodeAddOpType(builder, opType): builder.PrependInt8Slot(2, opType, 0)
@ -324,4 +369,10 @@ def FlatNodeAddOpName(builder, opName): builder.PrependUOffsetTRelativeSlot(16,
def FlatNodeAddOutputTypes(builder, outputTypes): builder.PrependUOffsetTRelativeSlot(17, flatbuffers.number_types.UOffsetTFlags.py_type(outputTypes), 0) def FlatNodeAddOutputTypes(builder, outputTypes): builder.PrependUOffsetTRelativeSlot(17, flatbuffers.number_types.UOffsetTFlags.py_type(outputTypes), 0)
def FlatNodeStartOutputTypesVector(builder, numElems): return builder.StartVector(1, numElems, 1) def FlatNodeStartOutputTypesVector(builder, numElems): return builder.StartVector(1, numElems, 1)
def FlatNodeAddScalar(builder, scalar): builder.PrependUOffsetTRelativeSlot(18, flatbuffers.number_types.UOffsetTFlags.py_type(scalar), 0) def FlatNodeAddScalar(builder, scalar): builder.PrependUOffsetTRelativeSlot(18, flatbuffers.number_types.UOffsetTFlags.py_type(scalar), 0)
def FlatNodeAddControlDeps(builder, controlDeps): builder.PrependUOffsetTRelativeSlot(19, flatbuffers.number_types.UOffsetTFlags.py_type(controlDeps), 0)
def FlatNodeStartControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def FlatNodeAddVarControlDeps(builder, varControlDeps): builder.PrependUOffsetTRelativeSlot(20, flatbuffers.number_types.UOffsetTFlags.py_type(varControlDeps), 0)
def FlatNodeStartVarControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def FlatNodeAddControlDepFor(builder, controlDepFor): builder.PrependUOffsetTRelativeSlot(21, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepFor), 0)
def FlatNodeStartControlDepForVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def FlatNodeEnd(builder): return builder.EndObject() def FlatNodeEnd(builder): return builder.EndObject()

View File

@ -37,6 +37,12 @@ public struct FlatVariable : IFlatbufferObject
public FlatArray? Ndarray { get { int o = __p.__offset(12); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } } public FlatArray? Ndarray { get { int o = __p.__offset(12); return o != 0 ? (FlatArray?)(new FlatArray()).__assign(__p.__indirect(o + __p.bb_pos), __p.bb) : null; } }
public int Device { get { int o = __p.__offset(14); return o != 0 ? __p.bb.GetInt(o + __p.bb_pos) : (int)0; } } public int Device { get { int o = __p.__offset(14); return o != 0 ? __p.bb.GetInt(o + __p.bb_pos) : (int)0; } }
public VarType Variabletype { get { int o = __p.__offset(16); return o != 0 ? (VarType)__p.bb.GetSbyte(o + __p.bb_pos) : VarType.VARIABLE; } } public VarType Variabletype { get { int o = __p.__offset(16); return o != 0 ? (VarType)__p.bb.GetSbyte(o + __p.bb_pos) : VarType.VARIABLE; } }
public string ControlDeps(int j) { int o = __p.__offset(18); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
public int ControlDepsLength { get { int o = __p.__offset(18); return o != 0 ? __p.__vector_len(o) : 0; } }
public string ControlDepForOp(int j) { int o = __p.__offset(20); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
public int ControlDepForOpLength { get { int o = __p.__offset(20); return o != 0 ? __p.__vector_len(o) : 0; } }
public string ControlDepsForVar(int j) { int o = __p.__offset(22); return o != 0 ? __p.__string(__p.__vector(o) + j * 4) : null; }
public int ControlDepsForVarLength { get { int o = __p.__offset(22); return o != 0 ? __p.__vector_len(o) : 0; } }
public static Offset<FlatVariable> CreateFlatVariable(FlatBufferBuilder builder, public static Offset<FlatVariable> CreateFlatVariable(FlatBufferBuilder builder,
Offset<IntPair> idOffset = default(Offset<IntPair>), Offset<IntPair> idOffset = default(Offset<IntPair>),
@ -45,8 +51,14 @@ public struct FlatVariable : IFlatbufferObject
VectorOffset shapeOffset = default(VectorOffset), VectorOffset shapeOffset = default(VectorOffset),
Offset<FlatArray> ndarrayOffset = default(Offset<FlatArray>), Offset<FlatArray> ndarrayOffset = default(Offset<FlatArray>),
int device = 0, int device = 0,
VarType variabletype = VarType.VARIABLE) { VarType variabletype = VarType.VARIABLE,
builder.StartObject(7); VectorOffset controlDepsOffset = default(VectorOffset),
VectorOffset controlDepForOpOffset = default(VectorOffset),
VectorOffset controlDepsForVarOffset = default(VectorOffset)) {
builder.StartObject(10);
FlatVariable.AddControlDepsForVar(builder, controlDepsForVarOffset);
FlatVariable.AddControlDepForOp(builder, controlDepForOpOffset);
FlatVariable.AddControlDeps(builder, controlDepsOffset);
FlatVariable.AddDevice(builder, device); FlatVariable.AddDevice(builder, device);
FlatVariable.AddNdarray(builder, ndarrayOffset); FlatVariable.AddNdarray(builder, ndarrayOffset);
FlatVariable.AddShape(builder, shapeOffset); FlatVariable.AddShape(builder, shapeOffset);
@ -57,7 +69,7 @@ public struct FlatVariable : IFlatbufferObject
return FlatVariable.EndFlatVariable(builder); return FlatVariable.EndFlatVariable(builder);
} }
public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(7); } public static void StartFlatVariable(FlatBufferBuilder builder) { builder.StartObject(10); }
public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); } public static void AddId(FlatBufferBuilder builder, Offset<IntPair> idOffset) { builder.AddOffset(0, idOffset.Value, 0); }
public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); } public static void AddName(FlatBufferBuilder builder, StringOffset nameOffset) { builder.AddOffset(1, nameOffset.Value, 0); }
public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); } public static void AddDtype(FlatBufferBuilder builder, DType dtype) { builder.AddSbyte(2, (sbyte)dtype, 0); }
@ -68,6 +80,18 @@ public struct FlatVariable : IFlatbufferObject
public static void AddNdarray(FlatBufferBuilder builder, Offset<FlatArray> ndarrayOffset) { builder.AddOffset(4, ndarrayOffset.Value, 0); } public static void AddNdarray(FlatBufferBuilder builder, Offset<FlatArray> ndarrayOffset) { builder.AddOffset(4, ndarrayOffset.Value, 0); }
public static void AddDevice(FlatBufferBuilder builder, int device) { builder.AddInt(5, device, 0); } public static void AddDevice(FlatBufferBuilder builder, int device) { builder.AddInt(5, device, 0); }
public static void AddVariabletype(FlatBufferBuilder builder, VarType variabletype) { builder.AddSbyte(6, (sbyte)variabletype, 0); } public static void AddVariabletype(FlatBufferBuilder builder, VarType variabletype) { builder.AddSbyte(6, (sbyte)variabletype, 0); }
public static void AddControlDeps(FlatBufferBuilder builder, VectorOffset controlDepsOffset) { builder.AddOffset(7, controlDepsOffset.Value, 0); }
public static VectorOffset CreateControlDepsVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
public static VectorOffset CreateControlDepsVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
public static void StartControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
public static void AddControlDepForOp(FlatBufferBuilder builder, VectorOffset controlDepForOpOffset) { builder.AddOffset(8, controlDepForOpOffset.Value, 0); }
public static VectorOffset CreateControlDepForOpVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
public static VectorOffset CreateControlDepForOpVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
public static void StartControlDepForOpVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
public static void AddControlDepsForVar(FlatBufferBuilder builder, VectorOffset controlDepsForVarOffset) { builder.AddOffset(9, controlDepsForVarOffset.Value, 0); }
public static VectorOffset CreateControlDepsForVarVector(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); for (int i = data.Length - 1; i >= 0; i--) builder.AddOffset(data[i].Value); return builder.EndVector(); }
public static VectorOffset CreateControlDepsForVarVectorBlock(FlatBufferBuilder builder, StringOffset[] data) { builder.StartVector(4, data.Length, 4); builder.Add(data); return builder.EndVector(); }
public static void StartControlDepsForVarVector(FlatBufferBuilder builder, int numElems) { builder.StartVector(4, numElems, 4); }
public static Offset<FlatVariable> EndFlatVariable(FlatBufferBuilder builder) { public static Offset<FlatVariable> EndFlatVariable(FlatBufferBuilder builder) {
int o = builder.EndObject(); int o = builder.EndObject();
return new Offset<FlatVariable>(o); return new Offset<FlatVariable>(o);

View File

@ -28,6 +28,12 @@ public final class FlatVariable extends Table {
public FlatArray ndarray(FlatArray obj) { int o = __offset(12); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; } public FlatArray ndarray(FlatArray obj) { int o = __offset(12); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; }
public int device() { int o = __offset(14); return o != 0 ? bb.getInt(o + bb_pos) : 0; } public int device() { int o = __offset(14); return o != 0 ? bb.getInt(o + bb_pos) : 0; }
public byte variabletype() { int o = __offset(16); return o != 0 ? bb.get(o + bb_pos) : 0; } public byte variabletype() { int o = __offset(16); return o != 0 ? bb.get(o + bb_pos) : 0; }
public String controlDeps(int j) { int o = __offset(18); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int controlDepsLength() { int o = __offset(18); return o != 0 ? __vector_len(o) : 0; }
public String controlDepForOp(int j) { int o = __offset(20); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int controlDepForOpLength() { int o = __offset(20); return o != 0 ? __vector_len(o) : 0; }
public String controlDepsForVar(int j) { int o = __offset(22); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int controlDepsForVarLength() { int o = __offset(22); return o != 0 ? __vector_len(o) : 0; }
public static int createFlatVariable(FlatBufferBuilder builder, public static int createFlatVariable(FlatBufferBuilder builder,
int idOffset, int idOffset,
@ -36,8 +42,14 @@ public final class FlatVariable extends Table {
int shapeOffset, int shapeOffset,
int ndarrayOffset, int ndarrayOffset,
int device, int device,
byte variabletype) { byte variabletype,
builder.startObject(7); int controlDepsOffset,
int controlDepForOpOffset,
int controlDepsForVarOffset) {
builder.startObject(10);
FlatVariable.addControlDepsForVar(builder, controlDepsForVarOffset);
FlatVariable.addControlDepForOp(builder, controlDepForOpOffset);
FlatVariable.addControlDeps(builder, controlDepsOffset);
FlatVariable.addDevice(builder, device); FlatVariable.addDevice(builder, device);
FlatVariable.addNdarray(builder, ndarrayOffset); FlatVariable.addNdarray(builder, ndarrayOffset);
FlatVariable.addShape(builder, shapeOffset); FlatVariable.addShape(builder, shapeOffset);
@ -48,7 +60,7 @@ public final class FlatVariable extends Table {
return FlatVariable.endFlatVariable(builder); return FlatVariable.endFlatVariable(builder);
} }
public static void startFlatVariable(FlatBufferBuilder builder) { builder.startObject(7); } public static void startFlatVariable(FlatBufferBuilder builder) { builder.startObject(10); }
public static void addId(FlatBufferBuilder builder, int idOffset) { builder.addOffset(0, idOffset, 0); } public static void addId(FlatBufferBuilder builder, int idOffset) { builder.addOffset(0, idOffset, 0); }
public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); } public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
public static void addDtype(FlatBufferBuilder builder, byte dtype) { builder.addByte(2, dtype, 0); } public static void addDtype(FlatBufferBuilder builder, byte dtype) { builder.addByte(2, dtype, 0); }
@ -58,6 +70,15 @@ public final class FlatVariable extends Table {
public static void addNdarray(FlatBufferBuilder builder, int ndarrayOffset) { builder.addOffset(4, ndarrayOffset, 0); } public static void addNdarray(FlatBufferBuilder builder, int ndarrayOffset) { builder.addOffset(4, ndarrayOffset, 0); }
public static void addDevice(FlatBufferBuilder builder, int device) { builder.addInt(5, device, 0); } public static void addDevice(FlatBufferBuilder builder, int device) { builder.addInt(5, device, 0); }
public static void addVariabletype(FlatBufferBuilder builder, byte variabletype) { builder.addByte(6, variabletype, 0); } public static void addVariabletype(FlatBufferBuilder builder, byte variabletype) { builder.addByte(6, variabletype, 0); }
public static void addControlDeps(FlatBufferBuilder builder, int controlDepsOffset) { builder.addOffset(7, controlDepsOffset, 0); }
public static int createControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static void addControlDepForOp(FlatBufferBuilder builder, int controlDepForOpOffset) { builder.addOffset(8, controlDepForOpOffset, 0); }
public static int createControlDepForOpVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startControlDepForOpVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static void addControlDepsForVar(FlatBufferBuilder builder, int controlDepsForVarOffset) { builder.addOffset(9, controlDepsForVarOffset, 0); }
public static int createControlDepsForVarVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startControlDepsForVarVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static int endFlatVariable(FlatBufferBuilder builder) {
int o = builder.endObject();
return o;
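For reference, a rough sketch of how the three new vectors might be written through the regenerated Java builder API shown above (the variable and op names are invented, and only a subset of fields is populated):

// Uses com.google.flatbuffers.FlatBufferBuilder and org.nd4j.graph.FlatVariable
FlatBufferBuilder fbb = new FlatBufferBuilder();
int name = fbb.createString("myVar");
int controlDeps = FlatVariable.createControlDepsVector(fbb, new int[]{fbb.createString("someOp")});
FlatVariable.startFlatVariable(fbb);            // strings/vectors must be created before this call
FlatVariable.addName(fbb, name);
FlatVariable.addControlDeps(fbb, controlDeps);
fbb.finish(FlatVariable.endFlatVariable(fbb));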


@ -90,7 +90,52 @@ class FlatVariable(object):
return self._tab.Get(flatbuffers.number_types.Int8Flags, o + self._tab.Pos)
return 0
# FlatVariable
def ControlDeps(self, j):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18))
if o != 0:
a = self._tab.Vector(o)
return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
return ""
# FlatVariable
def ControlDepsLength(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(18))
if o != 0:
return self._tab.VectorLen(o)
return 0
# FlatVariable
def ControlDepForOp(self, j):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
if o != 0:
a = self._tab.Vector(o)
return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
return ""
# FlatVariable
def ControlDepForOpLength(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(20))
if o != 0:
return self._tab.VectorLen(o)
return 0
# FlatVariable
def ControlDepsForVar(self, j):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
if o != 0:
a = self._tab.Vector(o)
return self._tab.String(a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
return ""
# FlatVariable
def ControlDepsForVarLength(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(22))
if o != 0:
return self._tab.VectorLen(o)
return 0
def FlatVariableStart(builder): builder.StartObject(10)
def FlatVariableAddId(builder, id): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(id), 0)
def FlatVariableAddName(builder, name): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(name), 0)
def FlatVariableAddDtype(builder, dtype): builder.PrependInt8Slot(2, dtype, 0)
@ -99,4 +144,10 @@ def FlatVariableStartShapeVector(builder, numElems): return builder.StartVector(
def FlatVariableAddNdarray(builder, ndarray): builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(ndarray), 0)
def FlatVariableAddDevice(builder, device): builder.PrependInt32Slot(5, device, 0)
def FlatVariableAddVariabletype(builder, variabletype): builder.PrependInt8Slot(6, variabletype, 0)
def FlatVariableAddControlDeps(builder, controlDeps): builder.PrependUOffsetTRelativeSlot(7, flatbuffers.number_types.UOffsetTFlags.py_type(controlDeps), 0)
def FlatVariableStartControlDepsVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def FlatVariableAddControlDepForOp(builder, controlDepForOp): builder.PrependUOffsetTRelativeSlot(8, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepForOp), 0)
def FlatVariableStartControlDepForOpVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def FlatVariableAddControlDepsForVar(builder, controlDepsForVar): builder.PrependUOffsetTRelativeSlot(9, flatbuffers.number_types.UOffsetTFlags.py_type(controlDepsForVar), 0)
def FlatVariableStartControlDepsForVarVector(builder, numElems): return builder.StartVector(4, numElems, 4)
def FlatVariableEnd(builder): return builder.EndObject()


@ -35,7 +35,10 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VT_OUTPUTNAMES = 34,
VT_OPNAME = 36,
VT_OUTPUTTYPES = 38,
VT_SCALAR = 40,
VT_CONTROLDEPS = 42,
VT_VARCONTROLDEPS = 44,
VT_CONTROLDEPFOR = 46
};
int32_t id() const {
return GetField<int32_t>(VT_ID, 0);
@ -94,6 +97,15 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const FlatArray *scalar() const {
return GetPointer<const FlatArray *>(VT_SCALAR);
}
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPS);
}
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *varControlDeps() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_VARCONTROLDEPS);
}
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPFOR);
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<int32_t>(verifier, VT_ID) &&
@ -132,6 +144,15 @@ struct FlatNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
verifier.VerifyVector(outputTypes()) &&
VerifyOffset(verifier, VT_SCALAR) &&
verifier.VerifyTable(scalar()) &&
VerifyOffset(verifier, VT_CONTROLDEPS) &&
verifier.VerifyVector(controlDeps()) &&
verifier.VerifyVectorOfStrings(controlDeps()) &&
VerifyOffset(verifier, VT_VARCONTROLDEPS) &&
verifier.VerifyVector(varControlDeps()) &&
verifier.VerifyVectorOfStrings(varControlDeps()) &&
VerifyOffset(verifier, VT_CONTROLDEPFOR) &&
verifier.VerifyVector(controlDepFor()) &&
verifier.VerifyVectorOfStrings(controlDepFor()) &&
verifier.EndTable();
}
};
@ -196,6 +217,15 @@ struct FlatNodeBuilder {
void add_scalar(flatbuffers::Offset<FlatArray> scalar) {
fbb_.AddOffset(FlatNode::VT_SCALAR, scalar);
}
void add_controlDeps(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps) {
fbb_.AddOffset(FlatNode::VT_CONTROLDEPS, controlDeps);
}
void add_varControlDeps(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> varControlDeps) {
fbb_.AddOffset(FlatNode::VT_VARCONTROLDEPS, varControlDeps);
}
void add_controlDepFor(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor) {
fbb_.AddOffset(FlatNode::VT_CONTROLDEPFOR, controlDepFor);
}
explicit FlatNodeBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@ -228,9 +258,15 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNode(
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> outputNames = 0,
flatbuffers::Offset<flatbuffers::String> opName = 0,
flatbuffers::Offset<flatbuffers::Vector<int8_t>> outputTypes = 0,
flatbuffers::Offset<FlatArray> scalar = 0,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps = 0,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> varControlDeps = 0,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepFor = 0) {
FlatNodeBuilder builder_(_fbb);
builder_.add_opNum(opNum);
builder_.add_controlDepFor(controlDepFor);
builder_.add_varControlDeps(varControlDeps);
builder_.add_controlDeps(controlDeps);
builder_.add_scalar(scalar);
builder_.add_outputTypes(outputTypes);
builder_.add_opName(opName);
@ -272,7 +308,10 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNodeDirect(
const std::vector<flatbuffers::Offset<flatbuffers::String>> *outputNames = nullptr,
const char *opName = nullptr,
const std::vector<int8_t> *outputTypes = nullptr,
flatbuffers::Offset<FlatArray> scalar = 0,
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps = nullptr,
const std::vector<flatbuffers::Offset<flatbuffers::String>> *varControlDeps = nullptr,
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDepFor = nullptr) {
return nd4j::graph::CreateFlatNode(
_fbb,
id,
@ -293,7 +332,10 @@ inline flatbuffers::Offset<FlatNode> CreateFlatNodeDirect(
outputNames ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*outputNames) : 0,
opName ? _fbb.CreateString(opName) : 0,
outputTypes ? _fbb.CreateVector<int8_t>(*outputTypes) : 0,
scalar,
controlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDeps) : 0,
varControlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*varControlDeps) : 0,
controlDepFor ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDepFor) : 0);
}
inline const nd4j::graph::FlatNode *GetFlatNode(const void *buf) {


@ -344,11 +344,65 @@ nd4j.graph.FlatNode.prototype.scalar = function(obj) {
return offset ? (obj || new nd4j.graph.FlatArray).__init(this.bb.__indirect(this.bb_pos + offset), this.bb) : null;
};
/**
* @param {number} index
* @param {flatbuffers.Encoding=} optionalEncoding
* @returns {string|Uint8Array}
*/
nd4j.graph.FlatNode.prototype.controlDeps = function(index, optionalEncoding) {
var offset = this.bb.__offset(this.bb_pos, 42);
return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null;
};
/**
* @returns {number}
*/
nd4j.graph.FlatNode.prototype.controlDepsLength = function() {
var offset = this.bb.__offset(this.bb_pos, 42);
return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
};
/**
* @param {number} index
* @param {flatbuffers.Encoding=} optionalEncoding
* @returns {string|Uint8Array}
*/
nd4j.graph.FlatNode.prototype.varControlDeps = function(index, optionalEncoding) {
var offset = this.bb.__offset(this.bb_pos, 44);
return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null;
};
/**
* @returns {number}
*/
nd4j.graph.FlatNode.prototype.varControlDepsLength = function() {
var offset = this.bb.__offset(this.bb_pos, 44);
return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
};
/**
* @param {number} index
* @param {flatbuffers.Encoding=} optionalEncoding
* @returns {string|Uint8Array}
*/
nd4j.graph.FlatNode.prototype.controlDepFor = function(index, optionalEncoding) {
var offset = this.bb.__offset(this.bb_pos, 46);
return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null;
};
/**
* @returns {number}
*/
nd4j.graph.FlatNode.prototype.controlDepForLength = function() {
var offset = this.bb.__offset(this.bb_pos, 46);
return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
};
/**
* @param {flatbuffers.Builder} builder
*/
nd4j.graph.FlatNode.startFlatNode = function(builder) {
builder.startObject(22);
};
/**
@ -713,6 +767,93 @@ nd4j.graph.FlatNode.addScalar = function(builder, scalarOffset) {
builder.addFieldOffset(18, scalarOffset, 0);
};
/**
* @param {flatbuffers.Builder} builder
* @param {flatbuffers.Offset} controlDepsOffset
*/
nd4j.graph.FlatNode.addControlDeps = function(builder, controlDepsOffset) {
builder.addFieldOffset(19, controlDepsOffset, 0);
};
/**
* @param {flatbuffers.Builder} builder
* @param {Array.<flatbuffers.Offset>} data
* @returns {flatbuffers.Offset}
*/
nd4j.graph.FlatNode.createControlDepsVector = function(builder, data) {
builder.startVector(4, data.length, 4);
for (var i = data.length - 1; i >= 0; i--) {
builder.addOffset(data[i]);
}
return builder.endVector();
};
/**
* @param {flatbuffers.Builder} builder
* @param {number} numElems
*/
nd4j.graph.FlatNode.startControlDepsVector = function(builder, numElems) {
builder.startVector(4, numElems, 4);
};
/**
* @param {flatbuffers.Builder} builder
* @param {flatbuffers.Offset} varControlDepsOffset
*/
nd4j.graph.FlatNode.addVarControlDeps = function(builder, varControlDepsOffset) {
builder.addFieldOffset(20, varControlDepsOffset, 0);
};
/**
* @param {flatbuffers.Builder} builder
* @param {Array.<flatbuffers.Offset>} data
* @returns {flatbuffers.Offset}
*/
nd4j.graph.FlatNode.createVarControlDepsVector = function(builder, data) {
builder.startVector(4, data.length, 4);
for (var i = data.length - 1; i >= 0; i--) {
builder.addOffset(data[i]);
}
return builder.endVector();
};
/**
* @param {flatbuffers.Builder} builder
* @param {number} numElems
*/
nd4j.graph.FlatNode.startVarControlDepsVector = function(builder, numElems) {
builder.startVector(4, numElems, 4);
};
/**
* @param {flatbuffers.Builder} builder
* @param {flatbuffers.Offset} controlDepForOffset
*/
nd4j.graph.FlatNode.addControlDepFor = function(builder, controlDepForOffset) {
builder.addFieldOffset(21, controlDepForOffset, 0);
};
/**
* @param {flatbuffers.Builder} builder
* @param {Array.<flatbuffers.Offset>} data
* @returns {flatbuffers.Offset}
*/
nd4j.graph.FlatNode.createControlDepForVector = function(builder, data) {
builder.startVector(4, data.length, 4);
for (var i = data.length - 1; i >= 0; i--) {
builder.addOffset(data[i]);
}
return builder.endVector();
};
/**
* @param {flatbuffers.Builder} builder
* @param {number} numElems
*/
nd4j.graph.FlatNode.startControlDepForVector = function(builder, numElems) {
builder.startVector(4, numElems, 4);
};
/**
* @param {flatbuffers.Builder} builder
* @returns {flatbuffers.Offset}


@ -57,7 +57,10 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VT_SHAPE = 10,
VT_NDARRAY = 12,
VT_DEVICE = 14,
VT_VARIABLETYPE = 16,
VT_CONTROLDEPS = 18,
VT_CONTROLDEPFOROP = 20,
VT_CONTROLDEPSFORVAR = 22
};
const IntPair *id() const {
return GetPointer<const IntPair *>(VT_ID);
@ -80,6 +83,15 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
VarType variabletype() const {
return static_cast<VarType>(GetField<int8_t>(VT_VARIABLETYPE, 0));
}
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPS);
}
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDepForOp() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPFOROP);
}
const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *controlDepsForVar() const {
return GetPointer<const flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>> *>(VT_CONTROLDEPSFORVAR);
}
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyOffset(verifier, VT_ID) &&
@ -93,6 +105,15 @@ struct FlatVariable FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
verifier.VerifyTable(ndarray()) &&
VerifyField<int32_t>(verifier, VT_DEVICE) &&
VerifyField<int8_t>(verifier, VT_VARIABLETYPE) &&
VerifyOffset(verifier, VT_CONTROLDEPS) &&
verifier.VerifyVector(controlDeps()) &&
verifier.VerifyVectorOfStrings(controlDeps()) &&
VerifyOffset(verifier, VT_CONTROLDEPFOROP) &&
verifier.VerifyVector(controlDepForOp()) &&
verifier.VerifyVectorOfStrings(controlDepForOp()) &&
VerifyOffset(verifier, VT_CONTROLDEPSFORVAR) &&
verifier.VerifyVector(controlDepsForVar()) &&
verifier.VerifyVectorOfStrings(controlDepsForVar()) &&
verifier.EndTable();
}
};
@ -121,6 +142,15 @@ struct FlatVariableBuilder {
void add_variabletype(VarType variabletype) {
fbb_.AddElement<int8_t>(FlatVariable::VT_VARIABLETYPE, static_cast<int8_t>(variabletype), 0);
}
void add_controlDeps(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps) {
fbb_.AddOffset(FlatVariable::VT_CONTROLDEPS, controlDeps);
}
void add_controlDepForOp(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepForOp) {
fbb_.AddOffset(FlatVariable::VT_CONTROLDEPFOROP, controlDepForOp);
}
void add_controlDepsForVar(flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepsForVar) {
fbb_.AddOffset(FlatVariable::VT_CONTROLDEPSFORVAR, controlDepsForVar);
}
explicit FlatVariableBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@ -141,8 +171,14 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariable(
flatbuffers::Offset<flatbuffers::Vector<int64_t>> shape = 0,
flatbuffers::Offset<FlatArray> ndarray = 0,
int32_t device = 0,
VarType variabletype = VarType_VARIABLE,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDeps = 0,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepForOp = 0,
flatbuffers::Offset<flatbuffers::Vector<flatbuffers::Offset<flatbuffers::String>>> controlDepsForVar = 0) {
FlatVariableBuilder builder_(_fbb);
builder_.add_controlDepsForVar(controlDepsForVar);
builder_.add_controlDepForOp(controlDepForOp);
builder_.add_controlDeps(controlDeps);
builder_.add_device(device);
builder_.add_ndarray(ndarray);
builder_.add_shape(shape);
@ -161,7 +197,10 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariableDirect(
const std::vector<int64_t> *shape = nullptr,
flatbuffers::Offset<FlatArray> ndarray = 0,
int32_t device = 0,
VarType variabletype = VarType_VARIABLE,
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDeps = nullptr,
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDepForOp = nullptr,
const std::vector<flatbuffers::Offset<flatbuffers::String>> *controlDepsForVar = nullptr) {
return nd4j::graph::CreateFlatVariable(
_fbb,
id,
@ -170,7 +209,10 @@ inline flatbuffers::Offset<FlatVariable> CreateFlatVariableDirect(
shape ? _fbb.CreateVector<int64_t>(*shape) : 0,
ndarray,
device,
variabletype,
controlDeps ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDeps) : 0,
controlDepForOp ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDepForOp) : 0,
controlDepsForVar ? _fbb.CreateVector<flatbuffers::Offset<flatbuffers::String>>(*controlDepsForVar) : 0);
}
inline const nd4j::graph::FlatVariable *GetFlatVariable(const void *buf) {


@ -125,11 +125,65 @@ nd4j.graph.FlatVariable.prototype.variabletype = function() {
return offset ? /** @type {nd4j.graph.VarType} */ (this.bb.readInt8(this.bb_pos + offset)) : nd4j.graph.VarType.VARIABLE;
};
/**
* @param {number} index
* @param {flatbuffers.Encoding=} optionalEncoding
* @returns {string|Uint8Array}
*/
nd4j.graph.FlatVariable.prototype.controlDeps = function(index, optionalEncoding) {
var offset = this.bb.__offset(this.bb_pos, 18);
return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null;
};
/**
* @returns {number}
*/
nd4j.graph.FlatVariable.prototype.controlDepsLength = function() {
var offset = this.bb.__offset(this.bb_pos, 18);
return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
};
/**
* @param {number} index
* @param {flatbuffers.Encoding=} optionalEncoding
* @returns {string|Uint8Array}
*/
nd4j.graph.FlatVariable.prototype.controlDepForOp = function(index, optionalEncoding) {
var offset = this.bb.__offset(this.bb_pos, 20);
return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null;
};
/**
* @returns {number}
*/
nd4j.graph.FlatVariable.prototype.controlDepForOpLength = function() {
var offset = this.bb.__offset(this.bb_pos, 20);
return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
};
/**
* @param {number} index
* @param {flatbuffers.Encoding=} optionalEncoding
* @returns {string|Uint8Array}
*/
nd4j.graph.FlatVariable.prototype.controlDepsForVar = function(index, optionalEncoding) {
var offset = this.bb.__offset(this.bb_pos, 22);
return offset ? this.bb.__string(this.bb.__vector(this.bb_pos + offset) + index * 4, optionalEncoding) : null;
};
/**
* @returns {number}
*/
nd4j.graph.FlatVariable.prototype.controlDepsForVarLength = function() {
var offset = this.bb.__offset(this.bb_pos, 22);
return offset ? this.bb.__vector_len(this.bb_pos + offset) : 0;
};
/**
* @param {flatbuffers.Builder} builder
*/
nd4j.graph.FlatVariable.startFlatVariable = function(builder) {
builder.startObject(10);
};
/**
@ -209,6 +263,93 @@ nd4j.graph.FlatVariable.addVariabletype = function(builder, variabletype) {
builder.addFieldInt8(6, variabletype, nd4j.graph.VarType.VARIABLE);
};
/**
* @param {flatbuffers.Builder} builder
* @param {flatbuffers.Offset} controlDepsOffset
*/
nd4j.graph.FlatVariable.addControlDeps = function(builder, controlDepsOffset) {
builder.addFieldOffset(7, controlDepsOffset, 0);
};
/**
* @param {flatbuffers.Builder} builder
* @param {Array.<flatbuffers.Offset>} data
* @returns {flatbuffers.Offset}
*/
nd4j.graph.FlatVariable.createControlDepsVector = function(builder, data) {
builder.startVector(4, data.length, 4);
for (var i = data.length - 1; i >= 0; i--) {
builder.addOffset(data[i]);
}
return builder.endVector();
};
/**
* @param {flatbuffers.Builder} builder
* @param {number} numElems
*/
nd4j.graph.FlatVariable.startControlDepsVector = function(builder, numElems) {
builder.startVector(4, numElems, 4);
};
/**
* @param {flatbuffers.Builder} builder
* @param {flatbuffers.Offset} controlDepForOpOffset
*/
nd4j.graph.FlatVariable.addControlDepForOp = function(builder, controlDepForOpOffset) {
builder.addFieldOffset(8, controlDepForOpOffset, 0);
};
/**
* @param {flatbuffers.Builder} builder
* @param {Array.<flatbuffers.Offset>} data
* @returns {flatbuffers.Offset}
*/
nd4j.graph.FlatVariable.createControlDepForOpVector = function(builder, data) {
builder.startVector(4, data.length, 4);
for (var i = data.length - 1; i >= 0; i--) {
builder.addOffset(data[i]);
}
return builder.endVector();
};
/**
* @param {flatbuffers.Builder} builder
* @param {number} numElems
*/
nd4j.graph.FlatVariable.startControlDepForOpVector = function(builder, numElems) {
builder.startVector(4, numElems, 4);
};
/**
* @param {flatbuffers.Builder} builder
* @param {flatbuffers.Offset} controlDepsForVarOffset
*/
nd4j.graph.FlatVariable.addControlDepsForVar = function(builder, controlDepsForVarOffset) {
builder.addFieldOffset(9, controlDepsForVarOffset, 0);
};
/**
* @param {flatbuffers.Builder} builder
* @param {Array.<flatbuffers.Offset>} data
* @returns {flatbuffers.Offset}
*/
nd4j.graph.FlatVariable.createControlDepsForVarVector = function(builder, data) {
builder.startVector(4, data.length, 4);
for (var i = data.length - 1; i >= 0; i--) {
builder.addOffset(data[i]);
}
return builder.endVector();
};
/**
* @param {flatbuffers.Builder} builder
* @param {number} numElems
*/
nd4j.graph.FlatVariable.startControlDepsForVarVector = function(builder, numElems) {
builder.startVector(4, numElems, 4);
};
/**
* @param {flatbuffers.Builder} builder
* @returns {flatbuffers.Offset}


@ -52,6 +52,12 @@ table FlatNode {
//Scalar value - used for scalar ops. Should be single value only.
scalar:FlatArray;
//Control dependencies
controlDeps:[string];
varControlDeps:[string];
controlDepFor:[string];
}
root_type FlatNode;
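Assuming the regenerated Java FlatNode exposes the usual string-vector accessors (controlDeps(int) and controlDepsLength(), mirroring the C++ and JavaScript accessors above), reading the new fields back might look like this sketch, where byteBuffer is a placeholder for a java.nio.ByteBuffer holding a serialized FlatNode:

FlatNode node = FlatNode.getRootAsFlatNode(byteBuffer);
for (int i = 0; i < node.controlDepsLength(); i++) {
    String opThisNodeDependsOn = node.controlDeps(i);       // ops that must run before this node
}
for (int i = 0; i < node.controlDepForLength(); i++) {
    String opThatDependsOnThisNode = node.controlDepFor(i); // ops that must run after this node
}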


@ -37,6 +37,10 @@ table FlatVariable {
device:int; // default is -1, which means _auto_
variabletype:VarType;
controlDeps:[string];
controlDepForOp:[string];
controlDepsForVar:[string];
}
root_type FlatVariable;


@ -659,7 +659,8 @@ public abstract class DifferentialFunction {
if(sameDiff == null)
this.ownName = UUID.randomUUID().toString();
else {
String n = sameDiff.getOpName(opName());
this.ownName = n;
}
if(sameDiff != null)
@ -696,30 +697,11 @@ public abstract class DifferentialFunction {
}
@JsonIgnore
public INDArray getInputArgument(int index){
//Subclasses should implement this
throw new UnsupportedOperationException("Not implemented");
}
private INDArray getX() {
INDArray ret = sameDiff.getArrForVarName(args()[0].getVarName());
return ret;
}
@JsonIgnore
private INDArray getY() {
if(args().length > 1) {
INDArray ret = sameDiff.getArrForVarName(args()[1].getVarName());
return ret;
}
return null;
}
@JsonIgnore
private INDArray getZ() {
if(isInPlace())
return getX();
SDVariable opId = outputVariables()[0];
INDArray ret = opId.getArr();
return ret;
}
/**
@ -860,4 +842,8 @@ public abstract class DifferentialFunction {
public int getNumOutputs(){return -1;}
/**
* Clear the input and output INDArrays, if any are set
*/
public abstract void clearArrays();
}
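To illustrate the new contract (this is not code from the PR), a concrete op that keeps its input and output INDArrays in lists might implement the two new methods roughly as follows; inputArguments and outputArguments are hypothetical backing fields:

@Override
public INDArray getInputArgument(int index){
    // hypothetical: return the requested input array, or null if it has not been set
    return index < inputArguments.size() ? inputArguments.get(index) : null;
}

@Override
public void clearArrays(){
    // drop array references so the memory manager can safely close or reuse them
    inputArguments.clear();
    outputArguments.clear();
}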


@ -982,8 +982,8 @@ public class DifferentialFunctionFactory {
return new CumProdBp(sameDiff(), in, grad, exclusive, reverse, axis).outputVariable();
}
public SDVariable biasAdd(SDVariable input, SDVariable bias, boolean nchw) {
return new BiasAdd(sameDiff(), input, bias, nchw).outputVariable();
}
public SDVariable[] biasAddBp(SDVariable input, SDVariable bias, SDVariable grad) {
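The new boolean selects the 4d data format (true for NCHW, false for NHWC). A rough usage sketch with invented names and shapes:

SameDiff sd = SameDiff.create();
SDVariable act = sd.placeHolder("act", DataType.FLOAT, -1, 16, 8, 8);  // NCHW activations
SDVariable bias = sd.var("bias", Nd4j.zeros(DataType.FLOAT, 16));      // one bias value per channel
SDVariable out = new BiasAdd(sd, act, bias, true).outputVariable();    // true => input is NCHW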


@ -24,6 +24,7 @@ import lombok.Getter;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.IMetric;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
@ -319,6 +320,7 @@ public class History {
* Gets the training evaluations ran during the last epoch
*/
public EvaluationRecord finalTrainingEvaluations(){
Preconditions.checkState(!trainingHistory.isEmpty(), "Cannot get final training evaluation - history is empty");
return trainingHistory.get(trainingHistory.size() - 1);
}
@ -326,6 +328,7 @@ public class History {
* Gets the validation evaluations ran during the last epoch
*/
public EvaluationRecord finalValidationEvaluations(){
Preconditions.checkState(!validationHistory.isEmpty(), "Cannot get final validation evaluation - history is empty");
return validationHistory.get(validationHistory.size() - 1);
}


@ -16,34 +16,23 @@
package org.nd4j.autodiff.samediff;
import java.util.Objects;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.blas.params.MMulTranspose;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.*;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.weightinit.WeightInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
/**
*
@ -167,6 +156,10 @@ public class SDVariable implements Serializable {
if(sameDiff.arrayAlreadyExistsForVarName(getVarName()))
return sameDiff.getArrForVarName(getVarName());
if(variableType == VariableType.ARRAY){
throw new UnsupportedOperationException("Cannot get array for ARRAY type SDVariable - use SDVariable.exec or SameDiff.output instead");
}
//initialize value if it's actually a scalar constant (zero or 1 typically...)
if(variableType == VariableType.VARIABLE && weightInitScheme != null && shape != null){
INDArray arr = weightInitScheme.create(dataType, shape);
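Put differently, ARRAY-type variables (op outputs) now have to be computed by executing the graph rather than read via getArr(); a minimal sketch of the suggested alternative, with made-up placeholder and variable names:

Map<String, INDArray> placeholders = Collections.singletonMap("in", inputArray);
Map<String, INDArray> outputs = sd.output(placeholders, arrayVar.getVarName());
INDArray value = outputs.get(arrayVar.getVarName());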
@ -211,8 +204,8 @@ public class SDVariable implements Serializable {
* created automatically when training is performed.
*/
public SDVariable getGradient() {
Preconditions.checkState(dataType().isFPType(), "Cannot get gradient of %s datatype variable \"%s\": only floating" +
" point variables have gradients", dataType(), getVarName());
return sameDiff.getGradForVariable(getVarName());
}
@ -230,7 +223,7 @@ public class SDVariable implements Serializable {
}
long[] initialShape = sameDiff.getShapeForVarName(getVarName());
if(initialShape == null && variableType != VariableType.ARRAY) {
val arr = getArr();
if(arr != null)
return arr.shape();
@ -254,7 +247,7 @@ public class SDVariable implements Serializable {
public DataType dataType() {
if(this.dataType == null){
//Try to infer datatype instead of returning null
if(variableType != VariableType.ARRAY && getArr() != null){
this.dataType = getArr().dataType();
}
}
@ -1518,26 +1511,59 @@ public class SDVariable implements Serializable {
/**
* Add a control dependency for this variable on the specified variable.<br>
* Control dependencies can be used to enforce the execution order.
* For example, if a control dependency X->Y exists, then Y will only be executed after X is executed - even
* if Y wouldn't normally depend on the result/values of X.
*
* @param controlDependency Control dependency to add for this variable
*/
public void addControlDependency(SDVariable controlDependency){
String cdN = controlDependency.getVarName();
String n = this.getVarName();
Variable v = sameDiff.getVariables().get(n);
if(v.getControlDeps() == null)
v.setControlDeps(new ArrayList<String>());
if(!v.getControlDeps().contains(cdN))
v.getControlDeps().add(cdN);
Variable v2 = sameDiff.getVariables().get(cdN);
if(v2.getControlDepsForVar() == null)
v2.setControlDepsForVar(new ArrayList<String>());
if(!v2.getControlDepsForVar().contains(n))
v2.getControlDepsForVar().add(n);
Variable vThis = sameDiff.getVariables().get(getVarName());
Variable vCD = sameDiff.getVariables().get(controlDependency.getVarName());
//If possible: add control dependency on ops
if(vThis.getOutputOfOp() != null && vCD.getOutputOfOp() != null ){
//Op -> Op case
SameDiffOp oThis = sameDiff.getOps().get(vThis.getOutputOfOp());
SameDiffOp oCD = sameDiff.getOps().get(vCD.getOutputOfOp());
if(oThis.getControlDeps() == null)
oThis.setControlDeps(new ArrayList<String>());
if(!oThis.getControlDeps().contains(oCD.getName()))
oThis.getControlDeps().add(oCD.getName());
if(oCD.getControlDepFor() == null)
oCD.setControlDepFor(new ArrayList<String>());
if(!oCD.getControlDepFor().contains(oThis.getName()))
oCD.getControlDepFor().add(oThis.getName());
} else {
if(vThis.getOutputOfOp() != null){
//const/ph -> op case
SameDiffOp oThis = sameDiff.getOps().get(vThis.getOutputOfOp());
if(oThis.getVarControlDeps() == null)
oThis.setVarControlDeps(new ArrayList<String>());
if(!oThis.getVarControlDeps().contains(vCD.getName()))
oThis.getVarControlDeps().add(vCD.getName());
if(vCD.getControlDepsForOp() == null)
vCD.setControlDepsForOp(new ArrayList<String>());
if(!vCD.getControlDepsForOp().contains(oThis.getName()))
vCD.getControlDepsForOp().add(oThis.getName());
} else {
//const/ph -> const/ph case
if(vThis.getControlDeps() == null)
vThis.setControlDeps(new ArrayList<String>());
if(!vThis.getControlDeps().contains(vCD.getName()))
vThis.getControlDeps().add(vCD.getName());
if(vCD.getControlDepsForVar() == null)
vCD.setControlDepsForVar(new ArrayList<String>());
if(!vCD.getControlDepsForVar().contains(vThis.getName()))
vCD.getControlDepsForVar().add(vThis.getName());
}
}
}
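A small usage sketch with invented names: force b to execute only after a, even though b does not consume a's output.

SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
SDVariable a = in.add(1.0);
SDVariable b = in.mul(2.0);
b.addControlDependency(a);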
/**


@ -16,58 +16,16 @@
package org.nd4j.autodiff.samediff;
import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
import com.google.flatbuffers.FlatBufferBuilder;
import java.io.BufferedInputStream;
import lombok.*;
import java.io.BufferedOutputStream;
import java.io.DataOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.Stack;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.nd4j.autodiff.execution.conf.ExecutorConfiguration;
import org.nd4j.autodiff.execution.conf.OutputMode;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.functions.DifferentialFunctionFactory;
import org.nd4j.autodiff.listeners.*;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.listeners.ListenerResponse;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.listeners.impl.HistoryListener;
import org.nd4j.autodiff.listeners.records.History;
import org.nd4j.autodiff.listeners.records.LossCurve;
@ -75,34 +33,14 @@ import org.nd4j.autodiff.samediff.config.BatchOutputConfig;
import org.nd4j.autodiff.samediff.config.EvaluationConfig;
import org.nd4j.autodiff.samediff.config.FitConfig;
import org.nd4j.autodiff.samediff.config.OutputConfig;
import org.nd4j.autodiff.samediff.internal.*;
import org.nd4j.autodiff.samediff.ops.*;
import org.nd4j.autodiff.samediff.internal.AbstractSession;
import org.nd4j.autodiff.samediff.internal.DataTypesSession;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.autodiff.samediff.ops.SDBaseOps;
import org.nd4j.autodiff.samediff.ops.SDBitwise;
import org.nd4j.autodiff.samediff.ops.SDCNN;
import org.nd4j.autodiff.samediff.ops.SDImage;
import org.nd4j.autodiff.samediff.ops.SDLoss;
import org.nd4j.autodiff.samediff.ops.SDMath;
import org.nd4j.autodiff.samediff.ops.SDNN;
import org.nd4j.autodiff.samediff.ops.SDRNN;
import org.nd4j.autodiff.samediff.ops.SDRandom;
import org.nd4j.autodiff.samediff.serde.FlatBuffersMapper;
import org.nd4j.base.Preconditions;
import org.nd4j.evaluation.IEvaluation;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.evaluation.classification.ROC;
import org.nd4j.graph.*;
import org.nd4j.graph.ExecutionMode;
import org.nd4j.graph.FlatArray;
import org.nd4j.graph.FlatConfiguration;
import org.nd4j.graph.FlatGraph;
import org.nd4j.graph.FlatNode;
import org.nd4j.graph.FlatVariable;
import org.nd4j.graph.IntPair;
import org.nd4j.graph.OpType;
import org.nd4j.graph.UpdaterState;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
@ -112,8 +50,6 @@ import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.api.ops.impl.controlflow.If;
import org.nd4j.linalg.api.ops.impl.controlflow.While;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray;
@ -136,7 +72,6 @@ import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.primitives.AtomicBoolean;
import org.nd4j.linalg.primitives.AtomicDouble;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.DeviceLocalNDArray;
@ -152,6 +87,17 @@ import org.nd4j.weightinit.impl.NDArraySupplierInitScheme;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.tensorflow.framework.GraphDef;
import java.io.*;
import java.lang.reflect.Method;
import java.nio.ByteBuffer;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import static org.nd4j.autodiff.util.TrainingUtils.stackOutputs;
/**
* SameDiff is the entrypoint for ND4J's automatic differentiation functionality.
* <p>
@ -683,7 +629,7 @@ public class SameDiff extends SDBaseOps {
for (val var : variables()) {
SDVariable clone = var.clone(this);
SDVariable newVar = sameDiff.var(clone);
if (var.getVariableType() != VariableType.ARRAY && var.getArr() != null ) { //ARRAY type = "activations" - are overwritten anyway
sameDiff.associateArrayWithVariable(var.getArr(), newVar);
}
@ -795,9 +741,9 @@ public class SameDiff extends SDBaseOps {
* @param function the function to get the inputs for
* @return the input ids for a given function
*/
public String[] getInputsForOp(@NonNull DifferentialFunction function) {
if (!ops.containsKey(function.getOwnName()))
throw new ND4JIllegalStateException("Unknown function instance id found: \"" + function.getOwnName() + "\"");
List<String> inputs = ops.get(function.getOwnName()).getInputsToOp();
return inputs == null ? null : inputs.toArray(new String[inputs.size()]);
}
@ -1102,12 +1048,8 @@ public class SameDiff extends SDBaseOps {
constantArrays.put(variable.getVarName(), new DeviceLocalNDArray(arr, true));
break;
case ARRAY:
throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY - arrays for" +
" this type of variable is calculated ");
// FIXME: remove this before release
val session = sessions.get(Thread.currentThread().getId());
val varId = session.newVarId(variable.getVarName(), AbstractSession.OUTER_FRAME, 0, null);
session.getNodeOutputs().put(varId, arr);
//throw new UnsupportedOperationException("Cannot associate array with SDVariable of type ARRAY");
break;
case PLACEHOLDER:
//Validate placeholder shapes:
long[] phShape = variable.placeholderShape();
@ -2152,11 +2094,32 @@ public class SameDiff extends SDBaseOps {
requiredVars.addAll(l.requiredVariables(this).trainingVariables());
}
List<Listener> listenersWitHistory = new ArrayList<>(listeners);
for(Listener l : this.listeners){
if(!listenersWitHistory.contains(l))
listenersWitHistory.add(l);
}
listenersWitHistory.add(history);
SameDiff gradInstance = getFunction("grad");
SameDiff gradInstance = getFunction("grad");
if(gradInstance == null){
createGradFunction();
gradInstance = getFunction("grad");
}
TrainingSession ts = new TrainingSession(gradInstance);
gradInstance.setTrainingConfig(trainingConfig); //In case any listeners want to use it
Set<String> paramsToTrain = new LinkedHashSet<>();
for(Variable v : variables.values()){
if(v.getVariable().getVariableType() == VariableType.VARIABLE){
//TODO not all variable type are needed - i.e., variable that doesn't impact loss should be skipped
paramsToTrain.add(v.getName());
}
}
Loss lastLoss = null;
for (int i = 0; i < numEpochs; i++) {
if (incrementEpochCount && hasListeners) { if (incrementEpochCount && hasListeners) {
at.setEpoch(trainingConfig.getEpochCount()); at.setEpoch(trainingConfig.getEpochCount());
for (Listener l : activeListeners) { for (Listener l : activeListeners) {
@ -2200,153 +2163,38 @@ public class SameDiff extends SDBaseOps {
Preconditions.checkState(placeholders.size() > 0, "No placeholder variables were set for training");
resolveVariablesWith(placeholders);
//Calculate gradients:
execBackwards(placeholders, at.operation(), ds, requiredVars, activeListeners);
//Apply updater:
//Call TrainingSession to perform training
if (!initializedTraining)
initializeTraining();
lastLoss = ts.trainingIteration(
trainingConfig,
placeholders,
paramsToTrain,
updaterMap,
ds,
getLossVariables(),
listenersWitHistory,
at);
Map<Class<?>, AtomicDouble> regScore = null; //Holds regularization scores for later reporting to listeners
if (hasListeners) {
regScore = new HashMap<>();
}
int iteration = trainingConfig.getIterationCount();
int e = trainingConfig.getEpochCount();
for (Variable v : variables.values()) {
//Only update trainable params - float type parameters (variable type vars)
SDVariable sdv = v.getVariable();
if (sdv.getVariableType() != VariableType.VARIABLE || !sdv.dataType().isFPType())
continue;
INDArray param = sdv.getArr();
SDVariable gradVar = sdv.getGradient();
if (gradVar == null) {
//Not all trainable parameters have gradients defined.
//Consider graph: in1->loss1; in2->loss2, where we optimize only loss1.
//No gradient will be present for in2, because in2 doesn't impact loss1 at all
continue;
}
INDArray grad = gradVar.getArr();
//Note: don't need to divide by minibatch - that should be handled in loss function and hence loss function gradients,
// which should flow through to here
//Pre-apply regularization (L1, L2)
List<Regularization> r = trainingConfig.getRegularization();
int iterCount = trainingConfig.getIterationCount();
int epochCount = trainingConfig.getEpochCount();
double lr = trainingConfig.getUpdater().hasLearningRate() ? trainingConfig.getUpdater().getLearningRate(iteration, epochCount) : 1.0;
if (r != null && r.size() > 0) {
for (Regularization reg : r) {
if (reg.applyStep() == Regularization.ApplyStep.BEFORE_UPDATER) {
reg.apply(param, grad, lr, iterCount, epochCount);
}
}
}
//Apply updater. Note that we need to reshape to [1,length] for updater
INDArray reshapedView = Shape.newShapeNoCopy(grad, new long[]{1, grad.length()}, grad.ordering() == 'f'); //TODO make sure we always reshape in same order!
Preconditions.checkState(reshapedView != null, "Error reshaping array for parameter \"%s\": array is a view?", sdv);
GradientUpdater u = updaterMap.get(sdv.getVarName());
try {
u.applyUpdater(reshapedView, iteration, e);
} catch (Throwable t) {
throw new RuntimeException("Error applying updater " + u.getClass().getSimpleName() + " to parameter \"" + sdv.getVarName()
+ "\": either parameter size is inconsistent between iterations, or \"" + sdv.getVarName() + "\" should not be a trainable parameter?", t);
}
//Post-apply regularization (weight decay)
if (r != null && r.size() > 0) {
for (Regularization reg : r) {
if (reg.applyStep() == Regularization.ApplyStep.POST_UPDATER) {
reg.apply(param, grad, lr, iterCount, epochCount);
if (hasListeners) {
double score = reg.score(param, iterCount, epochCount);
if (!regScore.containsKey(reg.getClass())) {
regScore.put(reg.getClass(), new AtomicDouble());
}
regScore.get(reg.getClass()).addAndGet(score);
}
}
}
}
if (hasListeners) {
for (Listener l : activeListeners) {
if (l.isActive(at.operation()))
l.preUpdate(this, at, v, reshapedView);
}
}
if (trainingConfig.isMinimize()) {
param.subi(grad);
} else {
param.addi(grad);
}
}
double[] d = new double[lossVariables.size() + regScore.size()];
List<String> lossVars;
if (regScore.size() > 0) {
lossVars = new ArrayList<>(lossVariables.size() + regScore.size());
lossVars.addAll(lossVariables);
int s = regScore.size();
//Collect regularization losses
for (Map.Entry<Class<?>, AtomicDouble> entry : regScore.entrySet()) {
lossVars.add(entry.getKey().getSimpleName());
d[s] = entry.getValue().get();
}
} else {
lossVars = lossVariables;
}
//Collect the losses...
SameDiff gradFn = sameDiffFunctionInstances.get(GRAD_FN_KEY);
int count = 0;
for (String s : lossVariables) {
INDArray arr = gradFn.getArrForVarName(s);
double l = arr.isScalar() ? arr.getDouble(0) : arr.sumNumber().doubleValue();
d[count++] = l;
}
Loss loss = new Loss(lossVars, d);
if (lossNames == null) {
lossNames = lossVars;
} else {
Preconditions.checkState(lossNames.equals(lossVars),
"Loss names mismatch, expected: %s, got: %s", lossNames, lossVars);
}
if (lossSums == null) {
lossSums = lastLoss.getLosses().clone();
} else {
Preconditions.checkState(lossNames.equals(lossVars),
"Loss size mismatch, expected: %s, got: %s", lossSums.length, d.length);
for (int j = 0; j < lossSums.length; j++) {
lossSums[j] += lastLoss.getLosses()[j];
}
}
lossCount++;
if (hasListeners) {
for (Listener l : activeListeners) {
l.iterationDone(this, at, ds, loss);
}
}
trainingConfig.incrementIterationCount();
}
long epochTime = System.currentTimeMillis() - epochStartTime;
if (incrementEpochCount) {
lossNames = lastLoss.getLossNames();
for (int j = 0; j < lossSums.length; j++)
lossSums[j] /= lossCount;
@ -2356,14 +2204,13 @@ public class SameDiff extends SDBaseOps {
lossCurve = new LossCurve(lossSums, lossNames);
}
if (incrementEpochCount) {
if (hasListeners) {
boolean doStop = false;
Listener stopped = null;
for (Listener l : activeListeners) {
ListenerResponse res = l.epochEnd(this, at, lossCurve, epochTime);
if (res == ListenerResponse.STOP && (i < numEpochs - 1)) {
@ -2431,7 +2278,6 @@ public class SameDiff extends SDBaseOps {
trainingConfig.incrementEpochCount();
}
if (i < numEpochs - 1) {
iter.reset();
}
@ -2507,7 +2353,9 @@ public class SameDiff extends SDBaseOps {
INDArray arr = v.getVariable().getArr();
long stateSize = trainingConfig.getUpdater().stateSize(arr.length());
INDArray view = stateSize == 0 ? null : Nd4j.createUninitialized(arr.dataType(), 1, stateSize);
GradientUpdater gu = trainingConfig.getUpdater().instantiate(view, false);
gu.setStateViewArray(view, arr.shape(), arr.ordering(), true);
updaterMap.put(v.getName(), gu);
}
initializedTraining = true;
@ -3862,7 +3710,8 @@ public class SameDiff extends SDBaseOps {
long thisSize = trainingConfig.getUpdater().stateSize(arr.length()); long thisSize = trainingConfig.getUpdater().stateSize(arr.length());
if (thisSize > 0) { if (thisSize > 0) {
INDArray stateArr = Nd4j.create(arr.dataType(), 1, thisSize); INDArray stateArr = Nd4j.create(arr.dataType(), 1, thisSize);
GradientUpdater u = trainingConfig.getUpdater().instantiate(stateArr, true); GradientUpdater u = trainingConfig.getUpdater().instantiate(stateArr, false);
u.setStateViewArray(stateArr, arr.shape(), arr.ordering(), true); //TODO eventually this should be 1 call...
updaterMap.put(v.getVarName(), u); updaterMap.put(v.getVarName(), u);
} else { } else {
GradientUpdater u = trainingConfig.getUpdater().instantiate((INDArray) null, true); GradientUpdater u = trainingConfig.getUpdater().instantiate((INDArray) null, true);
@ -3946,7 +3795,53 @@ public class SameDiff extends SDBaseOps {
sessions.clear(); sessions.clear();
//Recalculate datatypes of outputs, and dynamically update them //Recalculate datatypes of outputs, and dynamically update them
calculateOutputDataTypes(true); Set<String> allSeenOps = new HashSet<>();
Queue<String> queueOps = new LinkedList<>();
for(String s : dataTypeMap.keySet()){
Variable v = variables.get(s);
v.getVariable().setDataType(dataTypeMap.get(s));
List<String> inToOp = v.getInputsForOp();
if(inToOp != null){
for(String op : inToOp) {
if (!allSeenOps.contains(op)) {
allSeenOps.add(op);
queueOps.add(op);
}
}
}
}
while(!queueOps.isEmpty()){
String op = queueOps.remove();
SameDiffOp o = ops.get(op);
List<String> inVars = o.getInputsToOp();
List<DataType> inDTypes = new ArrayList<>();
if(inVars != null) {
for (String s : inVars) {
SDVariable v = variables.get(s).getVariable();
inDTypes.add(v.dataType());
}
}
List<DataType> outDtypes = o.getOp().calculateOutputDataTypes(inDTypes);
List<String> outVars = o.getOutputsOfOp();
for( int i=0; i<outVars.size(); i++ ){
String varName = outVars.get(i);
Variable var = variables.get(varName);
SDVariable v = var.getVariable();
v.setDataType(outDtypes.get(i));
//Also update queue
if(var.getInputsForOp() != null){
for(String opName : var.getInputsForOp()){
if(!allSeenOps.contains(opName)){
allSeenOps.add(opName);
queueOps.add(opName);
}
}
}
}
}
} }
} }
@ -4097,6 +3992,8 @@ public class SameDiff extends SDBaseOps {
break; break;
} }
} }
variables.get(varName).getInputsForOp().remove(function.getOwnName());
} }
/** /**
@ -4476,11 +4373,7 @@ public class SameDiff extends SDBaseOps {
else if (function instanceof BaseOp) { else if (function instanceof BaseOp) {
SDVariable[] ret = new SDVariable[1]; SDVariable[] ret = new SDVariable[1];
SDVariable checkGet = getVariable(baseName); SDVariable checkGet = getVariable(baseName);
char ordering = 'c';
SDVariable[] args = function.args(); SDVariable[] args = function.args();
if (args != null && args.length > 0 && function.args()[0].getArr() != null) { //Args may be null or length 0 for some ops, like eye
ordering = function.args()[0].getArr().ordering();
}
if (checkGet == null) { if (checkGet == null) {
//Note: output of an op is ARRAY type - activations, not a trainable parameter. Thus has no weight init scheme //Note: output of an op is ARRAY type - activations, not a trainable parameter. Thus has no weight init scheme
org.nd4j.linalg.api.buffer.DataType dataType = outputDataTypes.get(0); org.nd4j.linalg.api.buffer.DataType dataType = outputDataTypes.get(0);
@ -4530,45 +4423,6 @@ public class SameDiff extends SDBaseOps {
return sameDiffFunctionInstances.get(functionName); return sameDiffFunctionInstances.get(functionName);
} }
/**
* @deprecated Use {@link SDBaseOps#whileLoop(String[], String, SDVariable[], SameDiffSingleLambda, SameDiffLambda)}
*/
@Deprecated
public While whileStatement(SameDiffConditional sameDiffConditional,
SameDiffFunctionDefinition conditionBody,
SameDiffFunctionDefinition loopBody
, SDVariable[] inputVars) {
return While.builder()
.inputVars(inputVars)
.condition(conditionBody)
.predicate(sameDiffConditional)
.trueBody(loopBody)
.parent(this)
.blockName("while-" + UUID.randomUUID().toString())
.build();
}
/**
* @deprecated Use {@link SDBaseOps#ifCond(String, String, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda, SameDiffNoArgSingleLambda)}
*/
@Deprecated
public If ifStatement(SameDiffConditional conditional,
SameDiffFunctionDefinition conditionBody,
SameDiffFunctionDefinition trueBody,
SameDiffFunctionDefinition falseBody
, SDVariable[] inputVars) {
return If.builder()
.conditionBody(conditionBody)
.falseBody(falseBody)
.trueBody(trueBody)
.predicate(conditional)
.inputVars(inputVars)
.parent(this)
.blockName("if-" + UUID.randomUUID().toString())
.build();
}
/** /**
* Create a new TensorArray. * Create a new TensorArray.
*/ */
@ -4648,6 +4502,51 @@ public class SameDiff extends SDBaseOps {
return execSingle(placeholders, outputs.get(0)); return execSingle(placeholders, outputs.get(0));
} }
/**
* See {@link #calculateGradients(Map, Collection)}
*/
public Map<String, INDArray> calculateGradients(Map<String, INDArray> placeholderVals, @NonNull String... variables) {
Preconditions.checkArgument(variables.length > 0, "No variables were specified");
return calculateGradients(placeholderVals, Arrays.asList(variables));
}
/**
* Calculate and return the gradients for the specified variables
*
* @param placeholderVals Placeholders. May be null
* @param variables Names of the variables that you want the gradient arrays for
* @return Gradients as a map, keyed by the variable name
*/
public Map<String, INDArray> calculateGradients(Map<String, INDArray> placeholderVals, @NonNull Collection<String> variables) {
Preconditions.checkArgument(!variables.isEmpty(), "No variables were specified");
if (getFunction(GRAD_FN_KEY) == null) {
createGradFunction();
}
List<String> gradVarNames = new ArrayList<>(variables.size());
for (String s : variables) {
Preconditions.checkState(this.variables.containsKey(s), "No variable with name \"%s\" exists in the SameDiff instance", s);
SDVariable v = getVariable(s).getGradient();
if (v != null) {
//In a few cases (like the loss not depending on trainable parameters) we won't have a gradient array for a parameter variable
gradVarNames.add(v.getVarName());
}
}
//Key is gradient variable name
Map<String, INDArray> grads = getFunction(GRAD_FN_KEY).output(placeholderVals, gradVarNames);
Map<String, INDArray> out = new HashMap<>();
for (String s : variables) {
if (getVariable(s).getGradient() != null) {
String gradVar = getVariable(s).getGradient().getVarName();
out.put(s, grads.get(gradVar));
}
}
return out;
}
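For illustration, a minimal sketch of how this gradient API can be called. The graph construction below (variable names "in" and "w", the shapes, and the sum-of-squares "loss") is purely hypothetical and is not part of this change:

//Imports assumed: org.nd4j.autodiff.samediff.*, org.nd4j.linalg.api.buffer.DataType,
//org.nd4j.linalg.api.ndarray.INDArray, org.nd4j.linalg.factory.Nd4j, java.util.*
SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 4);
SDVariable w = sd.var("w", Nd4j.rand(DataType.FLOAT, 4, 3));
SDVariable mmul = in.mmul(w);
SDVariable loss = mmul.mul(mmul).sum();          //Simple sum-of-squares "loss"
sd.setLossVariables(loss.getVarName());

Map<String, INDArray> ph = Collections.singletonMap("in", Nd4j.rand(DataType.FLOAT, 2, 4));
//Result is keyed by the original variable name ("w"), not by the gradient variable name
Map<String, INDArray> grads = sd.calculateGradients(ph, "w");
INDArray dLdW = grads.get("w");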
/** /**
* Create (if required) and then calculate the variable gradients (backward pass) for this graph.<br> * Create (if required) and then calculate the variable gradients (backward pass) for this graph.<br>
* After execution, the gradient arrays can be accessed using {@code myVariable.getGradient().getArr()}<br> * After execution, the gradient arrays can be accessed using {@code myVariable.getGradient().getArr()}<br>
@ -4660,6 +4559,7 @@ public class SameDiff extends SDBaseOps {
* *
* @param placeholders Values for the placeholder variables in the graph. For graphs without placeholders, use null or an empty map * @param placeholders Values for the placeholder variables in the graph. For graphs without placeholders, use null or an empty map
*/ */
@Deprecated
public void execBackwards(Map<String, INDArray> placeholders, Operation op) { public void execBackwards(Map<String, INDArray> placeholders, Operation op) {
execBackwards(placeholders, op, null, Collections.<String>emptyList(), Collections.<Listener>emptyList()); execBackwards(placeholders, op, null, Collections.<String>emptyList(), Collections.<Listener>emptyList());
} }
@ -4669,10 +4569,12 @@ public class SameDiff extends SDBaseOps {
* <p> * <p>
* Uses {@link Operation#INFERENCE}. * Uses {@link Operation#INFERENCE}.
*/ */
@Deprecated
public void execBackwards(Map<String, INDArray> placeholders) { public void execBackwards(Map<String, INDArray> placeholders) {
execBackwards(placeholders, Operation.INFERENCE); execBackwards(placeholders, Operation.INFERENCE);
} }
@Deprecated
protected void execBackwards(Map<String, INDArray> placeholders, Operation op, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners) { protected void execBackwards(Map<String, INDArray> placeholders, Operation op, MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners) {
if (getFunction(GRAD_FN_KEY) == null) { if (getFunction(GRAD_FN_KEY) == null) {
createGradFunction(); createGradFunction();
@ -4709,6 +4611,7 @@ public class SameDiff extends SDBaseOps {
/** /**
* See {@link #execBackwards(Map, List, Operation)} * See {@link #execBackwards(Map, List, Operation)}
*/ */
@Deprecated
public Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, Operation op, String... variableGradNamesList) { public Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, Operation op, String... variableGradNamesList) {
return execBackwards(placeholders, Arrays.asList(variableGradNamesList), op, null, Collections.<String>emptyList(), Collections.<Listener>emptyList()); return execBackwards(placeholders, Arrays.asList(variableGradNamesList), op, null, Collections.<String>emptyList(), Collections.<Listener>emptyList());
} }
@ -4718,6 +4621,7 @@ public class SameDiff extends SDBaseOps {
* <p> * <p>
* Uses {@link Operation#INFERENCE}. * Uses {@link Operation#INFERENCE}.
*/ */
@Deprecated
public Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, String... variableGradNamesList) { public Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, String... variableGradNamesList) {
return execBackwards(placeholders, Operation.INFERENCE, variableGradNamesList); return execBackwards(placeholders, Operation.INFERENCE, variableGradNamesList);
} }
@ -4730,6 +4634,7 @@ public class SameDiff extends SDBaseOps {
* @param placeholders Values for the placeholder variables in the graph. For graphs without placeholders, use null or an empty map * @param placeholders Values for the placeholder variables in the graph. For graphs without placeholders, use null or an empty map
* @param variableGradNamesList Names of the gradient variables to calculate * @param variableGradNamesList Names of the gradient variables to calculate
*/ */
@Deprecated
public Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, List<String> variableGradNamesList, Operation operation) { public Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, List<String> variableGradNamesList, Operation operation) {
return execBackwards(placeholders, variableGradNamesList, operation, null, Collections.<String>emptyList(), Collections.<Listener>emptyList()); return execBackwards(placeholders, variableGradNamesList, operation, null, Collections.<String>emptyList(), Collections.<Listener>emptyList());
} }
@ -4739,10 +4644,12 @@ public class SameDiff extends SDBaseOps {
* <p> * <p>
* Uses {@link Operation#INFERENCE}. * Uses {@link Operation#INFERENCE}.
*/ */
@Deprecated
public Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, List<String> variableGradNamesList) { public Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, List<String> variableGradNamesList) {
return execBackwards(placeholders, variableGradNamesList, Operation.INFERENCE); return execBackwards(placeholders, variableGradNamesList, Operation.INFERENCE);
} }
@Deprecated
protected Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, List<String> variableGradNamesList, Operation operation, protected Map<String, INDArray> execBackwards(Map<String, INDArray> placeholders, List<String> variableGradNamesList, Operation operation,
MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners) { MultiDataSet batch, Collection<String> requiredActivations, List<Listener> activeListeners) {
if (getFunction(GRAD_FN_KEY) == null) { if (getFunction(GRAD_FN_KEY) == null) {
@ -5462,7 +5369,7 @@ public class SameDiff extends SDBaseOps {
0, 0,
0, 0,
-1, -1,
0, 0, 0, 0, 0, 0); 0, 0, 0, 0, 0, 0, 0, 0, 0);
return flatNode; return flatNode;
} }
@ -5538,7 +5445,7 @@ public class SameDiff extends SDBaseOps {
val idxForOps = new IdentityHashMap<DifferentialFunction, Integer>(); val idxForOps = new IdentityHashMap<DifferentialFunction, Integer>();
List<SDVariable> allVars = variables(); List<SDVariable> allVars = variables();
for (SDVariable variable : allVars) { for (SDVariable variable : allVars) {
INDArray arr = variable.getArr(); INDArray arr = variable.getVariableType() == VariableType.ARRAY ? null : variable.getArr();
log.trace("Exporting variable: [{}]", variable.getVarName()); log.trace("Exporting variable: [{}]", variable.getVarName());
//If variable is the output of some op - let's use the ONE index for exporting, and properly track the output //If variable is the output of some op - let's use the ONE index for exporting, and properly track the output
@ -5582,7 +5489,26 @@ public class SameDiff extends SDBaseOps {
shape = FlatVariable.createShapeVector(bufferBuilder, shp); shape = FlatVariable.createShapeVector(bufferBuilder, shp);
} }
int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(variable.dataType()), shape, array, -1, varType); int controlDeps = 0;
int controlDepsForOp = 0;
int controlDepsForVar = 0;
Variable v = variables.get(varName);
int[] cds = FlatBuffersMapper.mapOrNull(v.getControlDeps(), bufferBuilder);
if(cds != null)
controlDeps = FlatVariable.createControlDepsVector(bufferBuilder, cds);
int[] cdsForOp = FlatBuffersMapper.mapOrNull(v.getControlDepsForOp(), bufferBuilder);
if(cdsForOp != null)
controlDepsForOp = FlatVariable.createControlDepForOpVector(bufferBuilder, cdsForOp);
int[] cdsForVar = FlatBuffersMapper.mapOrNull(v.getControlDepsForVar(), bufferBuilder);
if(cdsForVar != null)
controlDepsForVar = FlatVariable.createControlDepsForVarVector(bufferBuilder, cdsForVar);
int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(variable.dataType()), shape,
array, -1, varType, controlDeps, controlDepsForOp, controlDepsForVar);
flatVariables.add(flatVariable); flatVariables.add(flatVariable);
} }
@ -5593,43 +5519,6 @@ public class SameDiff extends SDBaseOps {
flatNodes.add(FlatBuffersMapper.asFlatNode(this, func, bufferBuilder, variableList, reverseMap, forwardMap, framesMap, idCounter, fnId)); flatNodes.add(FlatBuffersMapper.asFlatNode(this, func, bufferBuilder, variableList, reverseMap, forwardMap, framesMap, idCounter, fnId));
} }
// we're dumping scopes now
for (Map.Entry<String, SameDiff> scope : sameDiffFunctionInstances.entrySet()) {
if (scope.getKey().equalsIgnoreCase(GRAD_FN_KEY)) {
//Skip the gradient function for export
continue;
}
flatNodes.add(asFlatNode(scope.getKey(), scope.getValue(), bufferBuilder));
val currVarList = new ArrayList<SDVariable>(scope.getValue().variables());
// converting all ops from node
for (val node : scope.getValue().variables()) {
INDArray arr = node.getArr();
if (arr == null) {
continue;
}
int name = bufferBuilder.createString(node.getVarName());
int array = arr.toFlatArray(bufferBuilder);
int id = IntPair.createIntPair(bufferBuilder, ++idx, 0);
val pair = parseVariable(node.getVarName());
reverseMap.put(pair.getFirst(), idx);
log.trace("Adding [{}] as [{}]", pair.getFirst(), idx);
byte varType = (byte) node.getVariableType().ordinal();
int flatVariable = FlatVariable.createFlatVariable(bufferBuilder, id, name, FlatBuffersMapper.getDataTypeAsByte(arr.dataType()), 0, array, -1, varType);
flatVariables.add(flatVariable);
}
//add functions
for (SameDiffOp op : scope.getValue().ops.values()) {
DifferentialFunction func = op.getOp();
flatNodes.add(FlatBuffersMapper.asFlatNode(this, func, bufferBuilder, currVarList, reverseMap, forwardMap, framesMap, idCounter, null));
}
}
int outputsOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatOffsets)); int outputsOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatOffsets));
int variablesOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatVariables)); int variablesOffset = FlatGraph.createVariablesVector(bufferBuilder, Ints.toArray(flatVariables));
int nodesOffset = FlatGraph.createNodesVector(bufferBuilder, Ints.toArray(flatNodes)); int nodesOffset = FlatGraph.createNodesVector(bufferBuilder, Ints.toArray(flatNodes));
@ -5958,7 +5847,7 @@ public class SameDiff extends SDBaseOps {
vars.add(fg.variables(i)); vars.add(fg.variables(i));
} }
FlatConfiguration conf = fg.configuration(); // FlatConfiguration conf = fg.configuration();
/* Reconstruct the graph /* Reconstruct the graph
We'll do the reconstruction manually here, rather than using sd.var(...), so that we have more control We'll do the reconstruction manually here, rather than using sd.var(...), so that we have more control
@ -5995,6 +5884,35 @@ public class SameDiff extends SDBaseOps {
SDVariable var = new SDVariable(n, vt, sd, shape, dtype, null); SDVariable var = new SDVariable(n, vt, sd, shape, dtype, null);
sd.variables.put(n, Variable.builder().name(n).variable(var).build()); sd.variables.put(n, Variable.builder().name(n).variable(var).build());
sd.variableNameToShape.put(n, shape); sd.variableNameToShape.put(n, shape);
Variable v2 = sd.variables.get(n);
//Reconstruct control dependencies
if(v.controlDepsLength() > 0){
int num = v.controlDepsLength();
List<String> l = new ArrayList<>(num);
for( int i=0; i<num; i++ ){
l.add(v.controlDeps(i));
}
v2.setControlDeps(l);
}
if(v.controlDepForOpLength() > 0){
int num = v.controlDepForOpLength();
List<String> l = new ArrayList<>(num);
for( int i=0; i<num; i++ ){
l.add(v.controlDepForOp(i));
}
v2.setControlDepsForOp(l);
}
if(v.controlDepsForVarLength() > 0){
int num = v.controlDepsForVarLength();
List<String> l = new ArrayList<>(num);
for( int i=0; i<num; i++ ){
l.add(v.controlDepsForVar(i));
}
v2.setControlDepsForVar(l);
}
FlatArray fa = v.ndarray(); FlatArray fa = v.ndarray();
@ -6063,7 +5981,37 @@ public class SameDiff extends SDBaseOps {
} }
inputNames[i] = varIn.getVarName(); inputNames[i] = varIn.getVarName();
} }
sd.ops.get(df.getOwnName()).setInputsToOp(Arrays.asList(inputNames)); SameDiffOp op = sd.ops.get(df.getOwnName());
op.setInputsToOp(Arrays.asList(inputNames));
//Reconstruct control dependencies
if (fn.controlDepsLength() > 0) {
int l = fn.controlDepsLength();
List<String> list = new ArrayList<>(l);
for( int i=0; i<l; i++ ){
list.add(fn.controlDeps(i));
}
op.setControlDeps(list);
}
if (fn.varControlDepsLength() > 0) {
int l = fn.varControlDepsLength();
List<String> list = new ArrayList<>(l);
for( int i=0; i<l; i++ ){
list.add(fn.varControlDeps(i));
}
op.setVarControlDeps(list);
}
if (fn.controlDepForLength() > 0) {
int l = fn.controlDepForLength();
List<String> list = new ArrayList<>(l);
for( int i=0; i<l; i++ ){
list.add(fn.controlDepFor(i));
}
op.setControlDepFor(list);
}
//Record that input variables are input to this op //Record that input variables are input to this op
for (String inName : inputNames) { for (String inName : inputNames) {
@ -6072,9 +6020,7 @@ public class SameDiff extends SDBaseOps {
v.setInputsForOp(new ArrayList<String>()); v.setInputsForOp(new ArrayList<String>());
} }
if (!v.getInputsForOp().contains(df.getOwnName())) { if (!v.getInputsForOp().contains(df.getOwnName())) {
v.getInputsForOp( v.getInputsForOp().add(df.getOwnName());
).add(df.getOwnName());
} }
} }
@ -6414,32 +6360,6 @@ public class SameDiff extends SDBaseOps {
return sb.toString(); return sb.toString();
} }
/**
* Calculate data types for the variables in the graph
*/
public Map<String, org.nd4j.linalg.api.buffer.DataType> calculateOutputDataTypes() {
return calculateOutputDataTypes(false);
}
/**
* Calculate data types for the variables in the graph
*/
public Map<String, org.nd4j.linalg.api.buffer.DataType> calculateOutputDataTypes(boolean dynamicUpdate) {
List<String> allVars = new ArrayList<>(variables.keySet());
DataTypesSession session = new DataTypesSession(this, dynamicUpdate);
Map<String, org.nd4j.linalg.api.buffer.DataType> phValues = new HashMap<>();
for (Variable v : variables.values()) {
if (v.getVariable().isPlaceHolder()) {
org.nd4j.linalg.api.buffer.DataType dt = v.getVariable().dataType();
Preconditions.checkNotNull(dt, "Placeholder variable %s has null datatype", v.getName());
phValues.put(v.getName(), dt);
}
}
Map<String, org.nd4j.linalg.api.buffer.DataType> out = session.output(allVars, phValues, null,
Collections.<String>emptyList(), Collections.<Listener>emptyList(), At.defaultAt(Operation.INFERENCE));
return out;
}
/** /**
* For internal use only. * For internal use only.
* Creates a new distinct block name from baseName. * Creates a new distinct block name from baseName.
@ -6470,14 +6390,14 @@ public class SameDiff extends SDBaseOps {
* @return The imported graph * @return The imported graph
*/ */
public static SameDiff importFrozenTF(File graphFile) { public static SameDiff importFrozenTF(File graphFile) {
return TFGraphMapper.getInstance().importGraph(graphFile); return TFGraphMapper.importGraph(graphFile);
} }
/** /**
* See {@link #importFrozenTF(File)} * See {@link #importFrozenTF(File)}
*/ */
public static SameDiff importFrozenTF(GraphDef graphDef) { public static SameDiff importFrozenTF(GraphDef graphDef) {
return TFGraphMapper.getInstance().importGraph(graphDef); return TFGraphMapper.importGraph(graphDef);
} }
@ -6487,7 +6407,7 @@ public class SameDiff extends SDBaseOps {
* Again, the input can be text or binary. * Again, the input can be text or binary.
*/ */
public static SameDiff importFrozenTF(InputStream graph) { public static SameDiff importFrozenTF(InputStream graph) {
return TFGraphMapper.getInstance().importGraph(graph); return TFGraphMapper.importGraph(graph);
} }
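As a usage sketch (the model path, the "input"/"output" names and the inputArr array below are hypothetical; use whatever names the imported graph actually defines):

SameDiff sd = SameDiff.importFrozenTF(new File("/path/to/frozen_model.pb"));
Map<String, INDArray> result = sd.output(Collections.singletonMap("input", inputArr), Arrays.asList("output"));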
@ -6511,7 +6431,7 @@ public class SameDiff extends SDBaseOps {
int start = 1; int start = 1;
// if we already have a name like "op_2", start from trying "op_3" // if we already have a name like "op_2", start from trying "op_3"
if (base.contains("_")) { if (base.contains("_") && base.matches(".*_\\d+")) {
// extract number used to generate base // extract number used to generate base
Matcher num = Pattern.compile("(.*)_(\\d+)").matcher(base); Matcher num = Pattern.compile("(.*)_(\\d+)").matcher(base);
// extract argIndex used to generate base // extract argIndex used to generate base


@ -0,0 +1,444 @@
package org.nd4j.autodiff.samediff.internal;
import lombok.Getter;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.function.Predicate;
import org.nd4j.linalg.primitives.Pair;
import java.util.*;
/**
* Object dependency tracker.
* <br>
* Dependencies are denoted by: X -> Y, which means "Y depends on X"<br>
* In this implementation:<br>
* - Dependencies may be satisfied, or not satisfied<br>
* - The implementation tracks when the dependencies for an object Y are fully satisfied. This occurs when:<br>
* 1. No dependencies X->Y exist<br>
* 2. All dependencies of the form X->Y have been marked as satisfied, via markSatisfied(x)<br>
* - When a dependency is satisfied, any dependent (Ys) are checked to see if all their dependencies are satisfied<br>
* - If a dependent has all dependencies satisfied, it is added to the "new all satisfied" queue for processing,
* which can be accessed via {@link #hasNewAllSatisfied()}, {@link #getNewAllSatisfied()} and {@link #getNewAllSatisfiedList()}<br>
* <br>
* Note: Two types of dependencies exist<br>
* 1. Standard dependencies - i.e., "Y depends on X"<br>
* 2. "Or" dependencies - i.e., "Y depends on (A or B)".<br>
* For Or dependencies of the form "(A or B) -> Y", Y will be marked as "all dependencies satisfied" if either A or B is marked as satisfied.
*
* @param <T> For a dependency X -> Y, Y has type T
* @param <D> For a dependency X -> Y, X has type D
*/
@Slf4j
public abstract class AbstractDependencyTracker<T, D> {
@Getter
private final Map<T, Set<D>> dependencies; //Key: the dependent. Value: all things that the key depends on
@Getter
private final Map<T, Set<Pair<D, D>>> orDependencies; //Key: the dependent. Value: the set of OR dependencies
private final Map<D, Set<T>> reverseDependencies = new HashMap<>(); //Key: the dependee. Value: The set of all dependents that depend on this value
private final Map<D, Set<T>> reverseOrDependencies = new HashMap<>();
private final Set<D> satisfiedDependencies = new HashSet<>(); //Dependencies that have been marked as satisfied. If not in set: assumed to not be satisfied
private final Set<T> allSatisfied; //Set of all dependent values (Ys) that have all dependencies satisfied
private final Queue<T> allSatisfiedQueue = new LinkedList<>(); //Queue for *new* "all satisfied" values. Values are removed using the "new all satisfied" methods
protected AbstractDependencyTracker() {
dependencies = (Map<T, Set<D>>) newTMap();
orDependencies = (Map<T, Set<Pair<D, D>>>) newTMap();
allSatisfied = newTSet();
}
/**
* @return A new map where the dependents (i.e., Y in "X -> Y") are the keys
*/
protected abstract Map<T, ?> newTMap();
/**
* @return A new set where the dependents (i.e., Y in "X -> Y") are the elements
*/
protected abstract Set<T> newTSet();
/**
* @return A String representation of the dependent object
*/
protected abstract String toStringT(T t);
/**
* @return A String representation of the dependee object
*/
protected abstract String toStringD(D d);
/**
* Clear all internal state for the dependency tracker
*/
public void clear() {
dependencies.clear();
orDependencies.clear();
reverseDependencies.clear();
reverseOrDependencies.clear();
satisfiedDependencies.clear();
allSatisfied.clear();
allSatisfiedQueue.clear();
}
/**
* @return True if no dependencies have been defined
*/
public boolean isEmpty() {
return dependencies.isEmpty() && orDependencies.isEmpty() &&
allSatisfiedQueue.isEmpty();
}
/**
* @return True if the dependency has been marked as satisfied using {@link #markSatisfied(Object, boolean)}
*/
public boolean isSatisfied(@NonNull D x) {
return satisfiedDependencies.contains(x);
}
/**
* Mark the specified value as satisfied.
* For example, if two dependencies have been previously added (X -> Y) and (X -> A) then after the markSatisfied(X, true)
* call, both of these dependencies are considered satisfied.
*
* @param x Value to mark
* @param satisfied Whether to mark as satisfied (true) or unsatisfied (false)
*/
public void markSatisfied(@NonNull D x, boolean satisfied) {
if (satisfied) {
boolean alreadySatisfied = satisfiedDependencies.contains(x);
if (!alreadySatisfied) {
satisfiedDependencies.add(x);
//Check if any Y's exist that have dependencies that are all satisfied, for X -> Y
Set<T> s = reverseDependencies.get(x);
Set<T> s2 = reverseOrDependencies.get(x);
Set<T> set;
if (s != null && s2 != null) {
set = newTSet();
set.addAll(s);
set.addAll(s2);
} else if (s != null) {
set = s;
} else if (s2 != null) {
set = s2;
} else {
if (log.isTraceEnabled()) {
log.trace("No values depend on: {}", toStringD(x));
}
return;
}
for (T t : set) {
Set<D> required = dependencies.get(t);
Set<Pair<D, D>> requiredOr = orDependencies.get(t);
boolean allSatisfied = true;
if (required != null) {
for (D d : required) {
if (!isSatisfied(d)) {
allSatisfied = false;
break;
}
}
}
if (allSatisfied && requiredOr != null) {
for (Pair<D, D> p : requiredOr) {
if (!isSatisfied(p.getFirst()) && !isSatisfied(p.getSecond())) {
allSatisfied = false;
break;
}
}
}
if (allSatisfied) {
if (!this.allSatisfied.contains(t)) {
this.allSatisfied.add(t);
this.allSatisfiedQueue.add(t);
}
}
}
}
} else {
satisfiedDependencies.remove(x);
if (!allSatisfied.isEmpty()) {
Set<T> reverse = reverseDependencies.get(x);
if (reverse != null) {
for (T y : reverse) {
if (allSatisfied.contains(y)) {
allSatisfied.remove(y);
allSatisfiedQueue.remove(y);
}
}
}
Set<T> orReverse = reverseOrDependencies.get(x);
if (orReverse != null) {
for (T y : orReverse) {
if (allSatisfied.contains(y) && !isAllSatisfied(y)) {
allSatisfied.remove(y);
allSatisfiedQueue.remove(y);
}
}
}
}
}
}
/**
* Check whether any dependencies x -> y exist, for y (i.e., anything previously added by {@link #addDependency(Object, Object)}
* or {@link #addOrDependency(Object, Object, Object)})
*
* @param y Dependent to check
* @return True if Y depends on any values
*/
public boolean hasDependency(@NonNull T y) {
Set<D> s1 = dependencies.get(y);
if (s1 != null && !s1.isEmpty())
return true;
Set<Pair<D, D>> s2 = orDependencies.get(y);
return s2 != null && !s2.isEmpty();
}
/**
* Get all dependencies x, for x -> y, and (x1 or x2) -> y
*
* @param y Dependent to get dependencies for
* @return List of dependencies
*/
public DependencyList<T, D> getDependencies(@NonNull T y) {
Set<D> s1 = dependencies.get(y);
Set<Pair<D, D>> s2 = orDependencies.get(y);
List<D> l1 = (s1 == null ? null : new ArrayList<>(s1));
List<Pair<D, D>> l2 = (s2 == null ? null : new ArrayList<>(s2));
return new DependencyList<>(y, l1, l2);
}
/**
* Add a dependency: y depends on x, as in x -> y
*
* @param y The dependent
* @param x The dependee that is required for Y
*/
public void addDependency(@NonNull T y, @NonNull D x) {
if (!dependencies.containsKey(y))
dependencies.put(y, new HashSet<D>());
if (!reverseDependencies.containsKey(x))
reverseDependencies.put(x, newTSet());
dependencies.get(y).add(x);
reverseDependencies.get(x).add(y);
checkAndUpdateIfAllSatisfied(y);
}
protected void checkAndUpdateIfAllSatisfied(@NonNull T y) {
boolean allSat = isAllSatisfied(y);
if (allSat) {
//Case where "x is satisfied" happened before x->y added
if (!allSatisfied.contains(y)) {
allSatisfied.add(y);
allSatisfiedQueue.add(y);
}
} else if (allSatisfied.contains(y)) {
if (!allSatisfiedQueue.contains(y)) {
StringBuilder sb = new StringBuilder();
sb.append("Dependent object \"").append(toStringT(y)).append("\" was previously processed after all dependencies")
.append(" were marked satisfied, but is now additional dependencies have been added.\n");
DependencyList<T, D> dl = getDependencies(y);
if (dl.getDependencies() != null) {
sb.append("Dependencies:\n");
for (D d : dl.getDependencies()) {
sb.append(d).append(" - ").append(isSatisfied(d) ? "Satisfied" : "Not satisfied").append("\n");
}
}
if (dl.getOrDependencies() != null) {
sb.append("Or dependencies:\n");
for (Pair<D, D> p : dl.getOrDependencies()) {
sb.append(p).append(" - satisfied=(").append(isSatisfied(p.getFirst())).append(",").append(isSatisfied(p.getSecond())).append(")");
}
}
throw new IllegalStateException(sb.toString());
}
//Not satisfied, but is in the queue -> needs to be removed
allSatisfied.remove(y);
allSatisfiedQueue.remove(y);
}
}
protected boolean isAllSatisfied(@NonNull T y) {
Set<D> set1 = dependencies.get(y);
boolean allSatisfied = true;
if (set1 != null) {
for (D d : set1) {
allSatisfied = isSatisfied(d);
if (!allSatisfied)
break;
}
}
if (allSatisfied) {
Set<Pair<D, D>> set2 = orDependencies.get(y);
if (set2 != null) {
for (Pair<D, D> p : set2) {
allSatisfied = isSatisfied(p.getFirst()) || isSatisfied(p.getSecond());
if (!allSatisfied)
break;
}
}
}
return allSatisfied;
}
/**
* Remove a dependency (x -> y)
*
* @param y The dependent that currently requires X
* @param x The dependee that is no longer required for Y
*/
public void removeDependency(@NonNull T y, @NonNull D x) {
if (!dependencies.containsKey(y) && !orDependencies.containsKey(y))
return;
Set<D> s = dependencies.get(y);
if (s != null) {
s.remove(x);
if (s.isEmpty())
dependencies.remove(y);
}
Set<T> s2 = reverseDependencies.get(x);
if (s2 != null) {
s2.remove(y);
if (s2.isEmpty())
reverseDependencies.remove(x);
}
Set<Pair<D, D>> s3 = orDependencies.get(y);
if (s3 != null) {
boolean removedReverse = false;
Iterator<Pair<D, D>> iter = s3.iterator();
while (iter.hasNext()) {
Pair<D, D> p = iter.next();
if (x.equals(p.getFirst()) || x.equals(p.getSecond())) {
iter.remove();
if (!removedReverse) {
Set<T> set1 = reverseOrDependencies.get(p.getFirst());
Set<T> set2 = reverseOrDependencies.get(p.getSecond());
set1.remove(y);
set2.remove(y);
if (set1.isEmpty())
reverseOrDependencies.remove(p.getFirst());
if (set2.isEmpty())
reverseOrDependencies.remove(p.getSecond());
removedReverse = true;
}
}
}
}
if (s3 != null && s3.isEmpty())
orDependencies.remove(y);
}
/**
* Add an "Or" dependency: Y requires either x1 OR x2 - i.e., (x1 or x2) -> Y<br>
* If either x1 or x2 (or both) are marked satisfied via {@link #markSatisfied(Object, boolean)} then the
* dependency is considered satisfied
*
* @param y Dependent
* @param x1 Dependee 1
* @param x2 Dependee 2
*/
public void addOrDependency(@NonNull T y, @NonNull D x1, @NonNull D x2) {
if (!orDependencies.containsKey(y))
orDependencies.put(y, new HashSet<Pair<D, D>>());
if (!reverseOrDependencies.containsKey(x1))
reverseOrDependencies.put(x1, newTSet());
if (!reverseOrDependencies.containsKey(x2))
reverseOrDependencies.put(x2, newTSet());
orDependencies.get(y).add(new Pair<>(x1, x2));
reverseOrDependencies.get(x1).add(y);
reverseOrDependencies.get(x2).add(y);
checkAndUpdateIfAllSatisfied(y);
}
/**
* @return True if there are any new/unprocessed "all satisfied dependents" (Ys in X->Y)
*/
public boolean hasNewAllSatisfied() {
return !allSatisfiedQueue.isEmpty();
}
/**
* Returns the next new dependent (Y in X->Y) that has all dependees (Xs) marked as satisfied via {@link #markSatisfied(Object, boolean)}
* Throws an exception if {@link #hasNewAllSatisfied()} returns false.<br>
* Note that once a value has been retrieved from here, no new dependencies of the form (X -> Y) can be added for this value;
* the value is considered "processed" at this point.
*
* @return The next new "all satisfied dependent"
*/
public T getNewAllSatisfied() {
Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied");
return allSatisfiedQueue.remove();
}
/**
* @return As per {@link #getNewAllSatisfied()} but returns all values
*/
public List<T> getNewAllSatisfiedList() {
Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied");
List<T> ret = new ArrayList<>(allSatisfiedQueue);
allSatisfiedQueue.clear();
return ret;
}
/**
* As per {@link #getNewAllSatisfied()} but instead of returning the first dependent, it returns the first one that matches
* the provided predicate. If no value matches the predicate, null is returned
*
* @param predicate Predicate for checking
* @return The first value matching the predicate, or null if no values match the predicate
*/
public T getFirstNewAllSatisfiedMatching(@NonNull Predicate<T> predicate) {
Preconditions.checkState(hasNewAllSatisfied(), "No new/unprocessed dependents that are all satisfied");
T t = allSatisfiedQueue.peek();
if (predicate.test(t)) {
t = allSatisfiedQueue.remove();
allSatisfied.remove(t);
return t;
}
if (allSatisfiedQueue.size() > 1) {
Iterator<T> iter = allSatisfiedQueue.iterator();
while (iter.hasNext()) {
t = iter.next();
if (predicate.test(t)) {
iter.remove();
allSatisfied.remove(t);
return t;
}
}
}
return null; //None match predicate
}
}
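To make the semantics above concrete, a small sketch using the HashMap-backed DependencyTracker subclass added later in this change; the string values are arbitrary placeholders:

DependencyTracker<String, String> tracker = new DependencyTracker<>();

//"opZ" depends on "x" AND on ("a" OR "b")
tracker.addDependency("opZ", "x");
tracker.addOrDependency("opZ", "a", "b");

tracker.markSatisfied("x", true);
boolean ready1 = tracker.hasNewAllSatisfied();   //false - the (a or b) dependency is still outstanding

tracker.markSatisfied("b", true);
boolean ready2 = tracker.hasNewAllSatisfied();   //true
String next = tracker.getNewAllSatisfied();      //"opZ"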


@ -1,107 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2019 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.autodiff.samediff.internal;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.nd4j.linalg.dataset.api.MultiDataSet;
/**
* Infer datatypes for all variables.
* Optionally update the datatypes of variables as we go
*/
public class DataTypesSession extends AbstractSession<DataType, DataTypesSession.DataTypeCalc> {
protected boolean dynamicUpdate;
/**
* @param sameDiff SameDiff instance
* @param dynamicUpdate If true: Dynamically update the datatypes as we go
*/
public DataTypesSession(SameDiff sameDiff, boolean dynamicUpdate) {
super(sameDiff);
this.dynamicUpdate = dynamicUpdate;
}
@Override
public DataType getConstantOrVariable(String variableName) {
//Variables and constants should always have datatype available
DataType dt = sameDiff.getVariable(variableName).dataType();
Preconditions.checkNotNull(dt, "No datatype available for variable %s", variableName);
return dt;
}
@Override
public DataTypeCalc getAndParameterizeOp(String opName, FrameIter frameIter, Set<VarId> inputs, Set<VarId> allIterInputs, Set<String> constAndPhInputs, Map<String, DataType> placeholderValues) {
DifferentialFunction df = sameDiff.getOpById(opName);
List<DataType> inputDataTypes = new ArrayList<>();
for(SDVariable v : df.args()){
DataType dt = v.dataType();
if(dt != null){
inputDataTypes.add(dt);
} else {
String s = v.getVarName();
for(VarId vid : inputs){
if(vid.getVariable().equals(s)){
DataType dt2 = nodeOutputs.get(vid);
Preconditions.checkNotNull(dt2, "No datatype for %s", vid);
inputDataTypes.add(dt2);
}
}
}
}
return new DataTypeCalc(df, inputDataTypes);
}
@Override
public DataType[] getOutputs(DataTypeCalc op, FrameIter outputFrameIter, Set<VarId> inputs, Set<VarId> allIterInputs,
Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch) {
List<DataType> outTypes = op.getFn().calculateOutputDataTypes(op.getInputTypes());
if(dynamicUpdate) {
SDVariable[] fnOutputs = op.getFn().outputVariables();
for( int i=0; i<fnOutputs.length; i++ ){
SDVariable v = fnOutputs[i];
DataType d = outTypes.get(i);
if(v.dataType() != d){
v.setDataType(d);
}
}
}
return outTypes.toArray(new DataType[outTypes.size()]);
}
@AllArgsConstructor
@Data
protected static class DataTypeCalc {
protected final DifferentialFunction fn;
protected final List<DataType> inputTypes;
}
}


@ -0,0 +1,20 @@
package org.nd4j.autodiff.samediff.internal;
import lombok.AllArgsConstructor;
import lombok.Data;
import org.nd4j.linalg.primitives.Pair;
import java.util.List;
/**
* A list of dependencies, used in {@link AbstractDependencyTracker}
*
* @author Alex Black
*/
@Data
@AllArgsConstructor
public class DependencyList<T, D> {
private T dependencyFor;
private List<D> dependencies;
private List<Pair<D, D>> orDependencies;
}


@ -0,0 +1,38 @@
package org.nd4j.autodiff.samediff.internal;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.primitives.Pair;
import java.util.*;
/**
* Dependency tracker. See {@link AbstractDependencyTracker} for details
*
* @param <T> For a dependency X -> Y, Y has type T
* @param <D> For a dependency X -> Y, X has type D
*/
@Slf4j
public class DependencyTracker<T, D> extends AbstractDependencyTracker<T,D> {
@Override
protected Map<T, ?> newTMap() {
return new HashMap<>();
}
@Override
protected Set<T> newTSet() {
return new HashSet<>();
}
@Override
protected String toStringT(T t) {
return t.toString();
}
@Override
protected String toStringD(D d) {
return d.toString();
}
}


@ -0,0 +1,44 @@
package org.nd4j.autodiff.samediff.internal;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.primitives.Pair;
import java.util.*;
/**
* Object dependency tracker, using object identity (not object equality) for the Ys (of type T)<br>
* See {@link AbstractDependencyTracker} for more details
*
* @author Alex Black
*/
@Slf4j
public class IdentityDependencyTracker<T, D> extends AbstractDependencyTracker<T,D> {
@Override
protected Map<T, ?> newTMap() {
return new IdentityHashMap<>();
}
@Override
protected Set<T> newTSet() {
return Collections.newSetFromMap(new IdentityHashMap<T, Boolean>());
}
@Override
protected String toStringT(T t) {
if(t instanceof INDArray){
INDArray i = (INDArray)t;
return System.identityHashCode(t) + " - id=" + i.getId() + ", " + i.shapeInfoToString();
} else {
return System.identityHashCode(t) + " - " + t.toString();
}
}
@Override
protected String toStringD(D d) {
return d.toString();
}
}
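A brief sketch of the identity (rather than equality) behaviour for the tracked dependents; the op name used as the dependee is arbitrary:

IdentityDependencyTracker<INDArray, String> tracker = new IdentityDependencyTracker<>();

INDArray a1 = Nd4j.ones(2, 2);
INDArray a2 = Nd4j.ones(2, 2);   //equals(a1) is true, but a different object

tracker.addDependency(a1, "someOp");
boolean trackedA1 = tracker.hasDependency(a1);   //true
boolean trackedA2 = tracker.hasDependency(a2);   //false - identity, not equality, is used for the Ys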


@ -30,8 +30,10 @@ import java.util.List;
@Builder @Builder
public class SameDiffOp { public class SameDiffOp {
protected String name; protected String name;
protected DifferentialFunction op; //Actual op (note: should be mutable: i.e., cloneable, no arrays set) protected DifferentialFunction op; //Actual op (note: should be mutable: i.e., cloneable, no arrays set)
protected List<String> inputsToOp; //Name of SDVariables as input protected List<String> inputsToOp; //Name of SDVariables as input
protected List<String> outputsOfOp; //Name of SDVariables as output protected List<String> outputsOfOp; //Name of SDVariables as output
protected List<String> controlDeps; //Name of SDVariables as control dependencies (not data inputs, but need to be available before exec) protected List<String> controlDeps; //Name of SDVariables as control dependencies (not data inputs, but need to be available before exec)
protected List<String> varControlDeps; //Variables (constants, placeholders, etc) that are control dependencies for this op
protected List<String> controlDepFor; //Name of the variables that this op is a control dependency for
} }


@ -0,0 +1,60 @@
package org.nd4j.autodiff.samediff.internal;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import java.io.Closeable;
/**
* SessionMemMgr - aka "Session Memory Manager" is responsible for allocating, managing, and deallocating memory used
* during SameDiff execution.<br>
* This interface allows different memory management strategies to be used, abstracted away from the actual graph
* execution logic
*
* @author Alex Black
*/
public interface SessionMemMgr extends Closeable {
/**
* Allocate an array with the specified datatype and shape.<br>
* NOTE: This array should be assumed to be uninitialized - i.e., contains random values.
*
* @param detached If true: the array is safe to return outside of the SameDiff session run (for example, the array
* is one that may be returned to the user)
* @param dataType Datatype of the returned array
* @param shape Array shape
* @return The newly allocated (uninitialized) array
*/
INDArray allocate(boolean detached, DataType dataType, long... shape);
/**
* As per {@link #allocate(boolean, DataType, long...)} but from a LongShapeDescriptor instead
*/
INDArray allocate(boolean detached, LongShapeDescriptor descriptor);
/**
* Allocate an uninitialized array with the same datatype and shape as the specified array
*/
INDArray ulike(INDArray arr);
/**
* Duplicate the specified array, to an array that is managed/allocated by the session memory manager
*/
INDArray dup(INDArray arr);
/**
* Release the array. All arrays allocated via one of the allocate methods should be returned here once they are no
* longer used, and all references to them should be cleared.
* After calling release, anything could occur to the array - deallocated, workspace closed, reused, etc.
*
* @param array The array that can be released
*/
void release(INDArray array);
/**
* Close the session memory manager and clean up any memory / resources, if any
*/
void close();
}


@ -0,0 +1,232 @@
package org.nd4j.autodiff.samediff.internal;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.listeners.At;
import org.nd4j.autodiff.listeners.Listener;
import org.nd4j.autodiff.listeners.Loss;
import org.nd4j.autodiff.listeners.Operation;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.TrainingConfig;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.learning.GradientUpdater;
import org.nd4j.linalg.learning.regularization.Regularization;
import org.nd4j.linalg.primitives.AtomicDouble;
import java.util.*;
/**
* TrainingSession extends InferenceSession, to add training-specific functionality:<br>
* - Application of regularization (L1, L2, weight decay etc)<br>
* - Inline updating of variables, using updater/optimizer (Adam, Nesterov, SGD, etc)<br>
* - Calculation of regularization scores (Score for L1, L2, etc)
*
* @author Alex Black
*/
@Slf4j
public class TrainingSession extends InferenceSession {
protected TrainingConfig config;
protected Map<String, String> gradVarToVarMap;
protected Map<String, GradientUpdater> updaters;
protected Map<String, Integer> lossVarsToLossIdx;
protected double[] currIterLoss;
protected Map<Class<?>, AtomicDouble> currIterRegLoss;
protected List<Listener> listeners;
public TrainingSession(SameDiff sameDiff) {
super(sameDiff);
}
/**
* Perform one iteration of training - i.e., do forward and backward passes, and update the parameters
*
* @param config Training configuration
* @param placeholders Current placeholders
* @param paramsToTrain Set of parameters that will be trained
* @param updaters Current updater state
* @param batch Current data/batch (mainly for listeners, should have already been converted to placeholders map)
* @param lossVariables Loss variables (names)
* @param listeners Listeners (if any)
* @param at Current epoch, iteration, etc
* @return The Loss at the current iteration
*/
public Loss trainingIteration(TrainingConfig config, Map<String, INDArray> placeholders, Set<String> paramsToTrain, Map<String, GradientUpdater> updaters,
MultiDataSet batch, List<String> lossVariables, List<Listener> listeners, At at) {
this.config = config;
this.updaters = updaters;
//Preprocess listeners, get the relevant ones
if (listeners == null) {
this.listeners = null;
} else {
List<Listener> filtered = new ArrayList<>();
for (Listener l : listeners) {
if (l.isActive(at.operation())) {
filtered.add(l);
}
}
this.listeners = filtered.isEmpty() ? null : filtered;
}
List<String> requiredActivations = new ArrayList<>();
gradVarToVarMap = new HashMap<>(); //Key: gradient variable. Value: variable that the key is gradient for
for (String s : paramsToTrain) {
Preconditions.checkState(sameDiff.hasVariable(s), "SameDiff instance does not have a variable with name \"%s\"", s);
SDVariable v = sameDiff.getVariable(s);
Preconditions.checkState(v.getVariableType() == VariableType.VARIABLE, "Can only train VARIABLE type variable - \"%s\" has type %s",
s, v.getVariableType());
SDVariable grad = sameDiff.getVariable(s).getGradient();
if (grad == null) {
//In some cases, a variable won't actually impact the loss value, and hence won't have a gradient associated with it
//For example: floatVar -> cast to integer -> cast to float -> sum -> loss
//In this case, the gradient of floatVar isn't defined (due to no floating point connection to the loss)
continue;
}
requiredActivations.add(grad.getVarName());
gradVarToVarMap.put(grad.getVarName(), s);
}
//Set up losses
lossVarsToLossIdx = new LinkedHashMap<>();
List<String> lossVars;
currIterLoss = new double[lossVariables.size()];
currIterRegLoss = new HashMap<>();
for (int i = 0; i < lossVariables.size(); i++) {
lossVarsToLossIdx.put(lossVariables.get(i), i);
}
//Do training iteration
List<String> outputVars = new ArrayList<>(gradVarToVarMap.keySet()); //TODO this should be empty, and grads calculated in requiredActivations
Map<String, INDArray> m = output(outputVars, placeholders, batch, requiredActivations, listeners, at);
double[] finalLoss = new double[currIterLoss.length + currIterRegLoss.size()];
System.arraycopy(currIterLoss, 0, finalLoss, 0, currIterLoss.length);
if (currIterRegLoss.size() > 0) {
lossVars = new ArrayList<>(lossVariables.size() + currIterRegLoss.size());
lossVars.addAll(lossVariables);
int s = currIterLoss.length;
//Collect regularization losses
for (Map.Entry<Class<?>, AtomicDouble> entry : currIterRegLoss.entrySet()) {
lossVars.add(entry.getKey().getSimpleName());
finalLoss[s++] = entry.getValue().get();
}
} else {
lossVars = lossVariables;
}
Loss loss = new Loss(lossVars, finalLoss);
if (listeners != null) {
for (Listener l : listeners) {
if (l.isActive(Operation.TRAINING)) {
l.iterationDone(sameDiff, at, batch, loss);
}
}
}
return loss;
}
@Override
public INDArray[] getOutputs(SameDiffOp op, FrameIter outputFrameIter, Set<VarId> opInputs, Set<VarId> allIterInputs,
Set<String> constAndPhInputs, List<Listener> listeners, At at, MultiDataSet batch, Set<String> allReqVariables) {
//Get outputs from InferenceSession
INDArray[] out = super.getOutputs(op, outputFrameIter, opInputs, allIterInputs, constAndPhInputs, listeners, at, batch, allReqVariables);
List<String> outputs = op.getOutputsOfOp();
int outIdx = 0;
for (String s : outputs) {
//If this is a loss variable - record it
if (lossVarsToLossIdx.containsKey(s)) {
int lossIdx = lossVarsToLossIdx.get(s);
INDArray arr = out[outIdx];
double l = arr.isScalar() ? arr.getDouble(0) : arr.sumNumber().doubleValue();
currIterLoss[lossIdx] += l;
}
//If this is a gradient variable - apply the updater and update the parameter array in-line
if (gradVarToVarMap.containsKey(s)) {
String varName = gradVarToVarMap.get(s);
//log.info("Calculated gradient for variable \"{}\": (grad var name: \"{}\")", varName, s);
Variable gradVar = sameDiff.getVariables().get(s);
if (gradVar.getInputsForOp() != null && !gradVar.getInputsForOp().isEmpty()) {
//Should be rare, and we should handle this by tracking dependencies, and only update when safe
// (i.e., dependency tracking)
throw new IllegalStateException("Op depends on gradient variable: " + s + " for variable " + varName);
}
GradientUpdater u = updaters.get(varName);
Preconditions.checkState(u != null, "No updater found for variable \"%s\"", varName);
Variable var = sameDiff.getVariables().get(varName);
INDArray gradArr = out[outIdx];
INDArray paramArr = var.getVariable().getArr();
//Pre-updater regularization (L1, L2)
List<Regularization> r = config.getRegularization();
if (r != null && r.size() > 0) {
double lr = config.getUpdater().hasLearningRate() ? config.getUpdater().getLearningRate(at.iteration(), at.epoch()) : 1.0;
for (Regularization reg : r) {
if (reg.applyStep() == Regularization.ApplyStep.BEFORE_UPDATER) {
if (this.listeners != null) {
double score = reg.score(paramArr, at.iteration(), at.epoch());
if (!currIterRegLoss.containsKey(reg.getClass())) {
currIterRegLoss.put(reg.getClass(), new AtomicDouble());
}
currIterRegLoss.get(reg.getClass()).addAndGet(score);
}
reg.apply(paramArr, gradArr, lr, at.iteration(), at.epoch());
}
}
}
u.applyUpdater(gradArr, at.iteration(), at.epoch());
//Post-apply regularization (weight decay)
if (r != null && r.size() > 0) {
double lr = config.getUpdater().hasLearningRate() ? config.getUpdater().getLearningRate(at.iteration(), at.epoch()) : 1.0;
for (Regularization reg : r) {
if (reg.applyStep() == Regularization.ApplyStep.POST_UPDATER) {
if (this.listeners != null) {
double score = reg.score(paramArr, at.iteration(), at.epoch());
if (!currIterRegLoss.containsKey(reg.getClass())) {
currIterRegLoss.put(reg.getClass(), new AtomicDouble());
}
currIterRegLoss.get(reg.getClass()).addAndGet(score);
}
reg.apply(paramArr, gradArr, lr, at.iteration(), at.epoch());
}
}
}
if (listeners != null) {
for (Listener l : listeners) {
if (l.isActive(at.operation()))
l.preUpdate(sameDiff, at, var, gradArr);
}
}
//Update:
if (config.isMinimize()) {
paramArr.subi(gradArr);
} else {
paramArr.addi(gradArr);
}
log.trace("Applied updater to gradient and updated variable: {}", varName);
}
outIdx++;
}
return out;
}
}
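As a rough sketch of the call shape: the locals sd, trainingConfig, placeholders, paramsToTrain, updaterMap, mds, lossVariableNames and activeListeners are hypothetical stand-ins for the state that the fit loop maintains; they are not defined in this file:

TrainingSession ts = new TrainingSession(sd);
At at = At.defaultAt(Operation.TRAINING);
Loss lastLoss = ts.trainingIteration(
        trainingConfig,        //TrainingConfig: updater, regularization, minimize flag, etc.
        placeholders,          //Map<String, INDArray> built from the current MultiDataSet
        paramsToTrain,         //Set<String> of VARIABLE-type parameter names
        updaterMap,            //Map<String, GradientUpdater>, one entry per parameter
        mds,                   //The current MultiDataSet (passed through to listeners)
        lossVariableNames,     //List<String> of loss variable names
        activeListeners,       //List<Listener>, may be empty
        at);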


@ -35,8 +35,7 @@ public class Variable {
protected List<String> controlDepsForOp; //if a op control dependency (x -> opY) exists, then "opY" will be in this list protected List<String> controlDepsForOp; //if a op control dependency (x -> opY) exists, then "opY" will be in this list
protected List<String> controlDepsForVar; //if a variable control dependency (x -> varY) exists, then "varY" will be in this list protected List<String> controlDepsForVar; //if a variable control dependency (x -> varY) exists, then "varY" will be in this list
protected String outputOfOp; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of protected String outputOfOp; //Null for placeholders/constants. For array type SDVariables, the name of the op it's an output of
protected List<String> controlDeps; //Control dependencies: name of variables that must be available before this variable is considered available for execution protected List<String> controlDeps; //Control dependencies: name of ops that must be available before this variable is considered available for execution
protected int outputOfOpIdx; //Index of the output for the op (say, variable is output number 2 of op "outputOfOp")
protected SDVariable gradient; //Variable corresponding to the gradient of this variable protected SDVariable gradient; //Variable corresponding to the gradient of this variable
protected int variableIndex = -1; protected int variableIndex = -1;
} }


@ -0,0 +1,25 @@
package org.nd4j.autodiff.samediff.internal.memory;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.linalg.api.ndarray.INDArray;
/**
* Abstract memory manager, that implements ulike and dup methods using the underlying allocate methods
*
* @author Alex Black
*/
public abstract class AbstractMemoryMgr implements SessionMemMgr {
@Override
public INDArray ulike(@NonNull INDArray arr) {
return allocate(false, arr.dataType(), arr.shape());
}
@Override
public INDArray dup(@NonNull INDArray arr) {
INDArray out = ulike(arr);
out.assign(arr);
return out;
}
}

View File

@ -0,0 +1,43 @@
package org.nd4j.autodiff.samediff.internal.memory;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;
/**
* A simple memory management strategy that deallocates memory as soon as it is no longer needed.<br>
* This should result in a minimal amount of memory, but will have some overhead - notably, the cost of deallocating
* and reallocating memory all the time.
*
* @author Alex Black
*/
@Slf4j
public class ArrayCloseMemoryMgr extends AbstractMemoryMgr implements SessionMemMgr {
@Override
public INDArray allocate(boolean detached, DataType dataType, long... shape) {
return Nd4j.createUninitialized(dataType, shape);
}
@Override
public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
return Nd4j.create(descriptor, false);
}
@Override
public void release(@NonNull INDArray array) {
if (!array.wasClosed() && array.closeable()) {
array.close();
log.trace("Closed array (deallocated) - id={}", array.getId());
}
}
@Override
public void close() {
//No-op
}
}
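A minimal sketch of the allocate/release contract defined by SessionMemMgr, using this implementation; the shape and values are arbitrary:

SessionMemMgr mmgr = new ArrayCloseMemoryMgr();

INDArray tmp = mmgr.allocate(false, DataType.FLOAT, 3, 4);   //Uninitialized scratch array
tmp.assign(1.0);
INDArray copy = mmgr.dup(tmp);                               //Managed duplicate with the same contents

mmgr.release(tmp);    //With this manager, the array is closed (deallocated) immediately
mmgr.release(copy);
mmgr.close();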


@ -0,0 +1,168 @@
package org.nd4j.autodiff.samediff.internal.memory;
import lombok.NonNull;
import lombok.extern.slf4j.Slf4j;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.DependencyList;
import org.nd4j.autodiff.samediff.internal.IdentityDependencyTracker;
import org.nd4j.autodiff.samediff.internal.InferenceSession;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.primitives.Pair;
import java.util.*;
/**
* A {@link SessionMemMgr} that wraps an existing memory manager, to ensure that:<br>
 * - All arrays that are supposed to be closed have been closed<br>
 * - Arrays are passed to the close method exactly once (unless they are requested outputs)<br>
* - Arrays that are passed to the close method were originally allocated by the session memory manager<br>
* <br>
* How to use:<br>
* 1. Perform an inference or training iteration, as normal<br>
* 2. Call {@link #assertAllReleasedExcept(Collection)} with the output arrays<br>
* <p>
* NOTE: This is intended for debugging and testing only
*
* @author Alex Black
*/
@Slf4j
public class CloseValidationMemoryMgr extends AbstractMemoryMgr implements SessionMemMgr {
private final SameDiff sd;
private final SessionMemMgr underlying;
private final Map<INDArray, Boolean> released = new IdentityHashMap<>();
public CloseValidationMemoryMgr(SameDiff sd, SessionMemMgr underlying) {
this.sd = sd;
this.underlying = underlying;
}
@Override
public INDArray allocate(boolean detached, DataType dataType, long... shape) {
INDArray out = underlying.allocate(detached, dataType, shape);
released.put(out, false);
return out;
}
@Override
public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
INDArray out = underlying.allocate(detached, descriptor);
released.put(out, false);
return out;
}
@Override
public void release(INDArray array) {
Preconditions.checkState(released.containsKey(array), "Attempting to release an array that was not allocated by" +
" this memory manager: id=%s", array.getId());
if (released.get(array)) {
//Already released
InferenceSession is = sd.getSessions().get(Thread.currentThread().getId());
IdentityDependencyTracker<INDArray, InferenceSession.Dep> arrayUseTracker = is.getArrayUseTracker();
DependencyList<INDArray, InferenceSession.Dep> dl = arrayUseTracker.getDependencies(array);
System.out.println(dl);
if (dl.getDependencies() != null) {
for (InferenceSession.Dep d : dl.getDependencies()) {
System.out.println(d + ": " + arrayUseTracker.isSatisfied(d));
}
}
if (dl.getOrDependencies() != null) {
for (Pair<InferenceSession.Dep, InferenceSession.Dep> p : dl.getOrDependencies()) {
System.out.println(p + " - (" + arrayUseTracker.isSatisfied(p.getFirst()) + "," + arrayUseTracker.isSatisfied(p.getSecond()) + ")");
}
}
}
Preconditions.checkState(!released.get(array), "Attempting to release an array that was already deallocated by" +
" an earlier release call to this memory manager: id=%s", array.getId());
log.trace("Released array: id = {}", array.getId());
released.put(array, true);
}
@Override
public void close() {
underlying.close();
}
/**
* Check that all arrays have been released (after an inference call) except for the specified arrays.
*
* @param except Arrays that should not have been closed (usually network outputs)
*/
public void assertAllReleasedExcept(@NonNull Collection<INDArray> except) {
Set<INDArray> allVarPhConst = null;
for (INDArray arr : except) {
if (!released.containsKey(arr)) {
//Check if constant, variable or placeholder - maybe the user requested it as an output
if (allVarPhConst == null)
allVarPhConst = identitySetAllConstPhVar();
if (allVarPhConst.contains(arr))
continue; //OK - output is a constant, variable or placeholder, hence it's fine that it was not allocated by the memory manager
throw new IllegalStateException("Array " + arr.getId() + " was not originally allocated by the memory manager");
}
boolean released = this.released.get(arr);
if (released) {
throw new IllegalStateException("Specified output array (id=" + arr.getId() + ") should not have been deallocated but was");
}
}
Set<INDArray> exceptSet = Collections.newSetFromMap(new IdentityHashMap<INDArray, Boolean>());
exceptSet.addAll(except);
int numNotClosed = 0;
Set<INDArray> notReleased = Collections.newSetFromMap(new IdentityHashMap<INDArray, Boolean>());
InferenceSession is = sd.getSessions().get(Thread.currentThread().getId());
IdentityDependencyTracker<INDArray, InferenceSession.Dep> arrayUseTracker = is.getArrayUseTracker();
for (Map.Entry<INDArray, Boolean> e : released.entrySet()) {
INDArray a = e.getKey();
if (!exceptSet.contains(a)) {
boolean b = e.getValue();
if (!b) {
notReleased.add(a);
numNotClosed++;
log.info("Not released: array id {}", a.getId());
DependencyList<INDArray, InferenceSession.Dep> list = arrayUseTracker.getDependencies(a);
List<InferenceSession.Dep> l = list.getDependencies();
List<Pair<InferenceSession.Dep, InferenceSession.Dep>> l2 = list.getOrDependencies();
if (l != null) {
for (InferenceSession.Dep d : l) {
if (!arrayUseTracker.isSatisfied(d)) {
log.info(" Not satisfied: {}", d);
}
}
}
if (l2 != null) {
for (Pair<InferenceSession.Dep, InferenceSession.Dep> d : l2) {
if (!arrayUseTracker.isSatisfied(d.getFirst()) && !arrayUseTracker.isSatisfied(d.getSecond())) {
log.info(" Not satisfied: {}", d);
}
}
}
}
}
}
if (numNotClosed > 0) {
System.out.println(sd.summary());
throw new IllegalStateException(numNotClosed + " arrays were not released but should have been");
}
}
protected Set<INDArray> identitySetAllConstPhVar() {
Set<INDArray> set = Collections.newSetFromMap(new IdentityHashMap<INDArray, Boolean>());
for (SDVariable v : sd.variables()) {
if (v.getVariableType() == VariableType.VARIABLE || v.getVariableType() == VariableType.CONSTANT || v.getVariableType() == VariableType.PLACEHOLDER) {
set.add(v.getArr());
}
}
return set;
}
}
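A hedged sketch of the debugging flow from the javadoc above, written as a helper method. How the InferenceSession is configured to allocate through the wrapping manager is assumed rather than shown in this hunk; names are illustrative.

public static void checkAllReleased(SameDiff sd, Collection<INDArray> requestedOutputs) {
CloseValidationMemoryMgr mgr = new CloseValidationMemoryMgr(sd, new ArrayCloseMemoryMgr());
//1. Run one inference or training iteration that allocates via mgr (session wiring assumed, not shown)
//2. Every array allocated during that iteration, other than the requested outputs, must have been released:
mgr.assertAllReleasedExcept(requestedOutputs);
}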

View File

@ -0,0 +1,42 @@
package org.nd4j.autodiff.samediff.internal.memory;
import lombok.NonNull;
import org.nd4j.autodiff.samediff.internal.SessionMemMgr;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.factory.Nd4j;
/**
* A simple "no-op" memory manager that relies on JVM garbage collector for memory management.
* Assuming other references have been cleared (they should have been) the arrays will be cleaned up by the
* garbage collector at some point.
*
* This memory management strategy is not recommended for performance or memory reasons, and should only be used
* for testing and debugging purposes
*
* @author Alex Black
*/
public class NoOpMemoryMgr extends AbstractMemoryMgr implements SessionMemMgr {
@Override
public INDArray allocate(boolean detached, DataType dataType, long... shape) {
return Nd4j.createUninitialized(dataType, shape);
}
@Override
public INDArray allocate(boolean detached, LongShapeDescriptor descriptor) {
return Nd4j.create(descriptor, false);
}
@Override
public void release(@NonNull INDArray array) {
//No-op, rely on GC to clear arrays
}
@Override
public void close() {
//No-op
}
}
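For contrast with ArrayCloseMemoryMgr above, a small sketch (illustrative only) of what the no-op strategy means in practice:

SessionMemMgr mgr = new NoOpMemoryMgr();
INDArray a = mgr.allocate(false, DataType.FLOAT, 2, 2);
mgr.release(a);   //No-op: a is still valid and readable here
//a is reclaimed by the JVM garbage collector at some later point, once no references to it remain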

View File

@ -90,10 +90,10 @@ public class SDNN extends SDOps {
}
/**
* @see #biasAdd(String, SDVariable, SDVariable, boolean)
*/
public SDVariable biasAdd(SDVariable input, SDVariable bias, boolean nchw) {
return biasAdd(null, input, bias, nchw);
}
/**
@ -102,12 +102,14 @@ public class SDNN extends SDOps {
* @param name Name of the output variable
* @param input 4d input variable
* @param bias 1d bias
* @param nchw The format - nchw=true means [minibatch, channels, height, width] format; nchw=false - [minibatch, height, width, channels].
* Unused for 2d inputs
* @return Output variable
*/
public SDVariable biasAdd(String name, SDVariable input, SDVariable bias, boolean nchw) {
validateFloatingPoint("biasAdd", "input", input);
validateFloatingPoint("biasAdd", "bias", bias);
SDVariable ret = f().biasAdd(input, bias, nchw);
return updateVariableNameAndReference(ret, name);
}
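A hedged usage sketch for the new nchw flag (assumes the SDNN instance is obtained via SameDiff.nn(); variable names and shapes are illustrative only):

SameDiff sd = SameDiff.create();
SDVariable in = sd.placeHolder("in", DataType.FLOAT, -1, 8, 16, 16); //[minibatch, channels, height, width]
SDVariable bias = sd.var("bias", Nd4j.create(DataType.FLOAT, 8));    //One bias value per channel
SDVariable out = sd.nn().biasAdd("out", in, bias, true);             //nchw=true: bias applied along dimension 1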

View File

@ -16,6 +16,7 @@
package org.nd4j.autodiff.samediff.serde;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.shade.guava.primitives.Ints;
import com.google.flatbuffers.FlatBufferBuilder;
import java.nio.ByteOrder;
@ -847,6 +848,28 @@ public class FlatBuffersMapper {
}
int outTypesOffset = FlatNode.createOutputTypesVector(bufferBuilder, outTypes);
//Control dependencies:
SameDiffOp sdo = sameDiff.getOps().get(node.getOwnName());
int opCds = 0;
int[] opCdsArr = mapOrNull(sdo.getControlDeps(), bufferBuilder);
if(opCdsArr != null){
opCds = FlatNode.createControlDepsVector(bufferBuilder, opCdsArr);
}
int varCds = 0;
int[] varCdsArr = mapOrNull(sdo.getVarControlDeps(), bufferBuilder);
if(varCdsArr != null){
varCds = FlatNode.createVarControlDepsVector(bufferBuilder, varCdsArr);
}
int cdsFor = 0;
int[] cdsForArr = mapOrNull(sdo.getControlDepFor(), bufferBuilder);
if(cdsForArr != null){
cdsFor = FlatNode.createControlDepForVector(bufferBuilder, cdsForArr);
}
int flatNode = FlatNode.createFlatNode(
bufferBuilder,
ownId,
@ -867,12 +890,26 @@ public class FlatBuffersMapper {
outVarNamesOffset,
opNameOffset,
outTypesOffset, //Output types
scalar,
opCds,
varCds,
cdsFor
);
return flatNode;
}
public static int[] mapOrNull(List<String> list, FlatBufferBuilder fbb){
if(list == null)
return null;
int[] out = new int[list.size()];
int i=0;
for(String s : list){
out[i++] = fbb.createString(s);
}
return out;
}
public static DifferentialFunction cloneViaSerialize(SameDiff sd, DifferentialFunction df ){
Map<String,Integer> nameToIdxMap = new HashMap<>();
int count = 0;

View File

@ -131,12 +131,12 @@ public class GradCheckUtil {
// in this case, gradients of x and y are all 0 too
//Collect variables to get gradients for - we want placeholders AND variables
Set<String> varsNeedingGrads = new HashSet<>();
for(Variable v : sd.getVariables().values()){
if(v.getVariable().dataType().isFPType() && (v.getVariable().getVariableType() == VariableType.VARIABLE || v.getVariable().getVariableType() == VariableType.PLACEHOLDER)){
SDVariable g = v.getVariable().getGradient();
Preconditions.checkNotNull(g, "No gradient variable found for variable %s", v.getVariable());
varsNeedingGrads.add(v.getName());
}
}
@ -164,7 +164,7 @@ public class GradCheckUtil {
}
Map<String,INDArray> gm = sd.calculateGradients(placeholderValues, varsNeedingGrads);
//Remove listener, to reduce overhead
sd.getListeners().remove(listenerIdx);
@ -183,11 +183,11 @@ public class GradCheckUtil {
if(g == null){
throw new IllegalStateException("Null gradient variable for \"" + v.getVarName() + "\"");
}
INDArray ga = gm.get(v.getVarName());
if(ga == null){
throw new IllegalStateException("Null gradient array encountered for variable: " + v.getVarName());
}
if(!Arrays.equals(v.getArr().shape(), ga.shape())){
throw new IllegalStateException("Gradient shape does not match variable shape for variable \"" +
v.getVarName() + "\": shape " + Arrays.toString(v.getArr().shape()) + " vs. gradient shape " +
Arrays.toString(ga.shape()));
@ -408,18 +408,18 @@ public class GradCheckUtil {
//Collect names of variables to get gradients for - i.e., the names of the GRADIENT variables for the specified activations
sd.createGradFunction();
Set<String> varsRequiringGrads = new HashSet<>();
for(String s : actGrads){
SDVariable grad = sd.getVariable(s).gradient();
Preconditions.checkState( grad != null,"Could not get gradient for activation \"%s\": gradient variable is null", s);
varsRequiringGrads.add(s);
}
//Calculate analytical gradients
Map<String,INDArray> grads = sd.calculateGradients(config.getPlaceholderValues(), new ArrayList<>(varsRequiringGrads));
Map<String,INDArray> gradientsForAct = new HashMap<>();
for(String s : actGrads){
INDArray arr = grads.get(s);
Preconditions.checkState(arr != null, "No activation gradient array for variable \"%s\"", s);
gradientsForAct.put(s, arr.dup());
}
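For clarity, a hedged sketch of the API shift used above (the SameDiff instance sd, the placeholder "in" and the variable "w" are assumed, not taken from this file): gradients now come back as a map keyed by the variable name, instead of being read from the gradient variable's array after execBackwards.

Map<String,INDArray> placeholders = Collections.singletonMap("in", Nd4j.rand(3, 4));
Map<String,INDArray> grads = sd.calculateGradients(placeholders, Arrays.asList("in", "w")); //Variables to get gradients for
INDArray dLdw = grads.get("w");                                                             //Keyed by variable name, not gradient name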

View File

@ -190,11 +190,13 @@ public class OpValidation {
//Check forward pass:
if (testCase.fwdTestFns() != null && testCase.fwdTestFns().size() > 0) {
SameDiff sd = testCase.sameDiff();
//Collect variables we need outputs for...
Set<String> reqVars = testCase.fwdTestFns().keySet();
Map<String,INDArray> out;
try {
out = sd.output(testCase.placeholderValues(), new ArrayList<>(reqVars));
} catch (Exception e) {
throw new RuntimeException("Error during forward pass testing" + testCase.testNameErrMsg(), e);
}
@ -206,7 +208,7 @@ public class OpValidation {
e.getKey() + "\" but SameDiff instance does not have a variable for this name" + testCase.testNameErrMsg());
}
INDArray actual = out.get(v.getVarName());
if (actual == null) {
throw new IllegalStateException("Null INDArray after forward pass for variable \"" + e.getKey() + "\"");
}
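The forward-pass check above uses the batch output API; a hedged sketch of the call shape (the SameDiff instance sd, placeholder and output names are assumed, not from the test case):

Map<String,INDArray> phValues = Collections.singletonMap("in", Nd4j.rand(3, 4));
Map<String,INDArray> outputs = sd.output(phValues, Arrays.asList("out")); //Requested output variables
INDArray o = outputs.get("out");                                          //Results keyed by variable name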
@ -291,6 +293,12 @@ public class OpValidation {
Preconditions.checkState((orig.getControlDeps() == null) == (des.getControlDeps() == null), "Control dependencies differ: %s vs. %s", orig.getControlDeps(), des.getControlDeps());
Preconditions.checkState(orig.getControlDeps() == null || orig.getControlDeps().equals(des.getControlDeps()), "Control dependencies differ: %s vs. %s", orig.getControlDeps(), des.getControlDeps());
Preconditions.checkState((orig.getVarControlDeps() == null) == (des.getVarControlDeps() == null), "Op variable control dependencies differ: %s vs. %s", orig.getVarControlDeps(), des.getVarControlDeps());
Preconditions.checkState(orig.getVarControlDeps() == null || orig.getVarControlDeps().equals(des.getVarControlDeps()), "Op variable control dependencies differ: %s vs. %s", orig.getVarControlDeps(), des.getVarControlDeps());
Preconditions.checkState((orig.getControlDepFor() == null) == (des.getControlDepFor() == null), "Op control dependencies for list differ: %s vs. %s", orig.getControlDepFor(), des.getControlDepFor());
Preconditions.checkState(orig.getControlDepFor() == null || orig.getControlDepFor().equals(des.getControlDepFor()), "Op control dependencies for list differ: %s vs. %s", orig.getControlDepFor(), des.getControlDepFor());
Preconditions.checkState(orig.getOp().getClass() == des.getOp().getClass(), "Classes differ: %s v. %s", orig.getOp().getClass(), des.getOp().getClass());
}
@ -317,6 +325,11 @@ public class OpValidation {
Map<String,Variable> varsBefore = original.getVariables();
Map<String,Variable> varsAfter = deserialized.getVariables();
Preconditions.checkState(varsBefore.keySet().equals(varsAfter.keySet()), "Variable keysets do not match: %s vs %s", varsBefore.keySet(), varsAfter.keySet());
// System.out.println(original.summary());
// System.out.println("\n\n\n\n");
// System.out.println(deserialized.summary());
for(String s : varsBefore.keySet()){
Variable vB = varsBefore.get(s);
Variable vA = varsAfter.get(s);
@ -324,13 +337,15 @@ public class OpValidation {
Preconditions.checkState(vB.getVariable().getVariableType() == vA.getVariable().getVariableType(),
"Variable types do not match: %s - %s vs %s", s, vB.getVariable().getVariableType(), vA.getVariable().getVariableType());
equalConsideringNull(vB.getInputsForOp(), vA.getInputsForOp(), "%s - Input to ops differ: %s vs. %s", s, vB.getInputsForOp(), vA.getInputsForOp());
Preconditions.checkState((vB.getOutputOfOp() == null && vA.getOutputOfOp() == null) || vB.getOutputOfOp().equals(vA.getOutputOfOp()), "%s - Output of op differ: %s vs. %s", s, vB.getOutputOfOp(), vA.getOutputOfOp());
equalConsideringNull(vB.getControlDeps(), vA.getControlDeps(), "%s - Control dependencies differ: %s vs. %s", s, vB.getControlDeps(), vA.getControlDeps());
equalConsideringNull(vB.getControlDepsForOp(), vA.getControlDepsForOp(), "%s - Control dependencies for ops differ: %s vs. %s", s, vB.getControlDepsForOp(), vA.getControlDepsForOp());
equalConsideringNull(vB.getControlDepsForVar(), vA.getControlDepsForVar(), "%s - Control dependencies for vars differ: %s vs. %s", s, vB.getControlDepsForVar(), vA.getControlDepsForVar());
}
//Check loss variables:
@ -343,51 +358,62 @@ public class OpValidation {
lossVarBefore, lossVarAfter);
}
if(tc.fwdTestFns() != null && !tc.fwdTestFns().isEmpty()) {
//Finally: check execution/output
Map<String,INDArray> outOrig = original.outputAll(tc.placeholderValues());
Map<String,INDArray> outDe = deserialized.outputAll(tc.placeholderValues());
Preconditions.checkState(outOrig.keySet().equals(outDe.keySet()), "Keysets for execution after deserialization does not match key set for original model");
for (String s : outOrig.keySet()) {
INDArray orig = outOrig.get(s);
INDArray deser = outDe.get(s);
Function<INDArray, String> f = tc.fwdTestFns().get(s);
String err = null;
if (f != null) {
err = f.apply(deser);
} else {
if (!orig.equals(deser)) {
//Edge case: check for NaNs in original and deserialized... might be legitimate test (like replaceNaNs op)
long count = orig.dataType().isNumerical() ? Nd4j.getExecutioner().execAndReturn(new MatchCondition(orig, Conditions.isNan())).getFinalResult().longValue() : -1;
if (orig.dataType().isNumerical() && count > 0 && orig.equalShapes(deser)) {
long count2 = Nd4j.getExecutioner().execAndReturn(new MatchCondition(deser, Conditions.isNan())).getFinalResult().longValue();
if (count != count2) {
err = "INDArray equality failed";
} else {
//TODO is there a better way to do this?
NdIndexIterator iter = new NdIndexIterator(orig.shape());
while (iter.hasNext()) {
long[] i = iter.next();
double d1 = orig.getDouble(i);
double d2 = deser.getDouble(i);
if ((Double.isNaN(d1) != Double.isNaN(d2)) || (Double.isInfinite(d1) != Double.isInfinite(d2)) || Math.abs(d1 - d2) > 1e-5) {
err = "INDArray equality failed";
break;
}
}
}
} else {
err = "INDArray equality failed";
}
}
}
Preconditions.checkState(err == null, "Variable result (%s) failed check - \"%ndSInfo\" vs \"%ndSInfo\" - %nd10 vs %nd10\nError:%s", s, orig, deser, orig, deser, err);
}
}
}
protected static void equalConsideringNull(List<String> l1, List<String> l2, String msg, Object... args){
//Consider null and length 0 list to be equal (semantically they mean the same thing)
boolean empty1 = l1 == null || l1.isEmpty();
boolean empty2 = l2 == null || l2.isEmpty();
if(empty1 && empty2){
return;
}
Preconditions.checkState(l1 == null || l1.equals(l2), msg, args);
}
/**
* Validate the outputs of a single op
*

View File

@ -25,6 +25,7 @@ public class NonInplaceValidationListener extends BaseListener {
private static AtomicInteger failCounter = new AtomicInteger();
protected INDArray[] opInputs;
protected INDArray[] opInputsOrig;
public NonInplaceValidationListener(){
useCounter.getAndIncrement();
@ -42,14 +43,18 @@ public class NonInplaceValidationListener extends BaseListener {
//No input op
return;
} else if(o.y() == null){
opInputsOrig = new INDArray[]{o.x()};
opInputs = new INDArray[]{o.x().dup()};
} else {
opInputsOrig = new INDArray[]{o.x(), o.y()};
opInputs = new INDArray[]{o.x().dup(), o.y().dup()};
}
} else if(op.getOp() instanceof DynamicCustomOp){
INDArray[] arr = ((DynamicCustomOp) op.getOp()).inputArguments();
opInputs = new INDArray[arr.length];
opInputsOrig = new INDArray[arr.length];
for( int i=0; i<arr.length; i++ ){
opInputsOrig[i] = arr[i];
opInputs[i] = arr[i].dup();
}
} else {
@ -64,23 +69,6 @@ public class NonInplaceValidationListener extends BaseListener {
return;
}
INDArray[] inputsAfter;
if(op.getOp() instanceof Op){
Op o = (Op)op.getOp();
if(o.x() == null){
//No input op
return;
} else if(o.y() == null){
inputsAfter = new INDArray[]{o.x()};
} else {
inputsAfter = new INDArray[]{o.x(), o.y()};
}
} else if(op.getOp() instanceof DynamicCustomOp){
inputsAfter = ((DynamicCustomOp) op.getOp()).inputArguments();
} else {
throw new IllegalStateException("Unknown op type: " + op.getOp().getClass());
}
MessageDigest md;
try {
md = MessageDigest.getInstance("MD5");
@ -93,12 +81,12 @@ public class NonInplaceValidationListener extends BaseListener {
//Need to hash - to ensure zero changes to input array
byte[] before = opInputs[i].data().asBytes();
INDArray after = this.opInputsOrig[i];
boolean dealloc = false;
if(opInputs[i].ordering() != opInputsOrig[i].ordering() || Arrays.equals(opInputs[i].stride(), opInputsOrig[i].stride())
|| opInputs[i].elementWiseStride() != opInputsOrig[i].elementWiseStride()){
//Clone if required (otherwise fails for views etc)
after = opInputsOrig[i].dup();
dealloc = true;
}
byte[] afterB = after.data().asBytes();

View File

@ -67,29 +67,41 @@ public final class FlatNode extends Table {
public ByteBuffer outputTypesInByteBuffer(ByteBuffer _bb) { return __vector_in_bytebuffer(_bb, 38, 1); }
public FlatArray scalar() { return scalar(new FlatArray()); }
public FlatArray scalar(FlatArray obj) { int o = __offset(40); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; }
public String controlDeps(int j) { int o = __offset(42); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int controlDepsLength() { int o = __offset(42); return o != 0 ? __vector_len(o) : 0; }
public String varControlDeps(int j) { int o = __offset(44); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int varControlDepsLength() { int o = __offset(44); return o != 0 ? __vector_len(o) : 0; }
public String controlDepFor(int j) { int o = __offset(46); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int controlDepForLength() { int o = __offset(46); return o != 0 ? __vector_len(o) : 0; }
public static int createFlatNode(FlatBufferBuilder builder,
int id,
int nameOffset,
byte opType,
long opNum,
int propertiesOffset,
int inputOffset,
int inputPairedOffset,
int outputOffset,
int extraParamsOffset,
int extraIntegerOffset,
int extraBoolsOffset,
int dimensionsOffset,
int device,
int scope_id,
int scope_nameOffset,
int outputNamesOffset,
int opNameOffset,
int outputTypesOffset,
int scalarOffset,
int controlDepsOffset,
int varControlDepsOffset,
int controlDepForOffset) {
builder.startObject(22);
FlatNode.addOpNum(builder, opNum);
FlatNode.addControlDepFor(builder, controlDepForOffset);
FlatNode.addVarControlDeps(builder, varControlDepsOffset);
FlatNode.addControlDeps(builder, controlDepsOffset);
FlatNode.addScalar(builder, scalarOffset);
FlatNode.addOutputTypes(builder, outputTypesOffset);
FlatNode.addOpName(builder, opNameOffset);
@ -111,7 +123,7 @@ public final class FlatNode extends Table {
return FlatNode.endFlatNode(builder);
}
public static void startFlatNode(FlatBufferBuilder builder) { builder.startObject(22); }
public static void addId(FlatBufferBuilder builder, int id) { builder.addInt(0, id, 0); }
public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
public static void addOpType(FlatBufferBuilder builder, byte opType) { builder.addByte(2, opType, 0); }
@ -151,6 +163,15 @@ public final class FlatNode extends Table {
public static int createOutputTypesVector(FlatBufferBuilder builder, byte[] data) { builder.startVector(1, data.length, 1); for (int i = data.length - 1; i >= 0; i--) builder.addByte(data[i]); return builder.endVector(); }
public static void startOutputTypesVector(FlatBufferBuilder builder, int numElems) { builder.startVector(1, numElems, 1); }
public static void addScalar(FlatBufferBuilder builder, int scalarOffset) { builder.addOffset(18, scalarOffset, 0); }
public static void addControlDeps(FlatBufferBuilder builder, int controlDepsOffset) { builder.addOffset(19, controlDepsOffset, 0); }
public static int createControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static void addVarControlDeps(FlatBufferBuilder builder, int varControlDepsOffset) { builder.addOffset(20, varControlDepsOffset, 0); }
public static int createVarControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startVarControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static void addControlDepFor(FlatBufferBuilder builder, int controlDepForOffset) { builder.addOffset(21, controlDepForOffset, 0); }
public static int createControlDepForVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startControlDepForVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static int endFlatNode(FlatBufferBuilder builder) {
int o = builder.endObject();
return o;

View File

@ -29,16 +29,28 @@ public final class FlatVariable extends Table {
public FlatArray ndarray(FlatArray obj) { int o = __offset(12); return o != 0 ? obj.__assign(__indirect(o + bb_pos), bb) : null; }
public int device() { int o = __offset(14); return o != 0 ? bb.getInt(o + bb_pos) : 0; }
public byte variabletype() { int o = __offset(16); return o != 0 ? bb.get(o + bb_pos) : 0; }
public String controlDeps(int j) { int o = __offset(18); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int controlDepsLength() { int o = __offset(18); return o != 0 ? __vector_len(o) : 0; }
public String controlDepForOp(int j) { int o = __offset(20); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int controlDepForOpLength() { int o = __offset(20); return o != 0 ? __vector_len(o) : 0; }
public String controlDepsForVar(int j) { int o = __offset(22); return o != 0 ? __string(__vector(o) + j * 4) : null; }
public int controlDepsForVarLength() { int o = __offset(22); return o != 0 ? __vector_len(o) : 0; }
public static int createFlatVariable(FlatBufferBuilder builder,
int idOffset,
int nameOffset,
byte dtype,
int shapeOffset,
int ndarrayOffset,
int device,
byte variabletype,
int controlDepsOffset,
int controlDepForOpOffset,
int controlDepsForVarOffset) {
builder.startObject(10);
FlatVariable.addControlDepsForVar(builder, controlDepsForVarOffset);
FlatVariable.addControlDepForOp(builder, controlDepForOpOffset);
FlatVariable.addControlDeps(builder, controlDepsOffset);
FlatVariable.addDevice(builder, device);
FlatVariable.addNdarray(builder, ndarrayOffset);
FlatVariable.addShape(builder, shapeOffset);
@ -49,7 +61,7 @@ public final class FlatVariable extends Table {
return FlatVariable.endFlatVariable(builder);
}
public static void startFlatVariable(FlatBufferBuilder builder) { builder.startObject(10); }
public static void addId(FlatBufferBuilder builder, int idOffset) { builder.addOffset(0, idOffset, 0); }
public static void addName(FlatBufferBuilder builder, int nameOffset) { builder.addOffset(1, nameOffset, 0); }
public static void addDtype(FlatBufferBuilder builder, byte dtype) { builder.addByte(2, dtype, 0); }
@ -59,6 +71,15 @@ public final class FlatVariable extends Table {
public static void addNdarray(FlatBufferBuilder builder, int ndarrayOffset) { builder.addOffset(4, ndarrayOffset, 0); }
public static void addDevice(FlatBufferBuilder builder, int device) { builder.addInt(5, device, 0); }
public static void addVariabletype(FlatBufferBuilder builder, byte variabletype) { builder.addByte(6, variabletype, 0); }
public static void addControlDeps(FlatBufferBuilder builder, int controlDepsOffset) { builder.addOffset(7, controlDepsOffset, 0); }
public static int createControlDepsVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startControlDepsVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static void addControlDepForOp(FlatBufferBuilder builder, int controlDepForOpOffset) { builder.addOffset(8, controlDepForOpOffset, 0); }
public static int createControlDepForOpVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startControlDepForOpVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static void addControlDepsForVar(FlatBufferBuilder builder, int controlDepsForVarOffset) { builder.addOffset(9, controlDepsForVarOffset, 0); }
public static int createControlDepsForVarVector(FlatBufferBuilder builder, int[] data) { builder.startVector(4, data.length, 4); for (int i = data.length - 1; i >= 0; i--) builder.addOffset(data[i]); return builder.endVector(); }
public static void startControlDepsForVarVector(FlatBufferBuilder builder, int numElems) { builder.startVector(4, numElems, 4); }
public static int endFlatVariable(FlatBufferBuilder builder) {
int o = builder.endObject();
return o;
@ -67,3 +88,4 @@ public final class FlatVariable extends Table {
public static void finishSizePrefixedFlatVariableBuffer(FlatBufferBuilder builder, int offset) { builder.finishSizePrefixed(offset); }
}

View File

@ -25,11 +25,7 @@ import org.nd4j.imports.descriptors.onnx.OnnxDescriptorParser;
import org.nd4j.imports.descriptors.onnx.OpDescriptor;
import org.nd4j.imports.descriptors.tensorflow.TensorflowDescriptorParser;
import org.nd4j.linalg.api.ops.*;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.*;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Merge;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.NextIteration;
import org.nd4j.linalg.api.ops.impl.controlflow.compat.Switch;
import org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction;
import org.nd4j.linalg.api.ops.impl.layers.convolution.*;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
@ -370,6 +366,8 @@ public class DifferentialFunctionClassHolder {
return Merge.class;
case Switch.OP_NAME:
return Switch.class;
case LoopCond.OP_NAME:
return LoopCond.class;
case ExternalErrorsFunction.OP_NAME:
return ExternalErrorsFunction.class;
default:

View File

@ -69,13 +69,9 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThan.class,
org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThanOrEqual.class,
org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastNotEqual.class,
org.nd4j.linalg.api.ops.impl.controlflow.If.class,
org.nd4j.linalg.api.ops.impl.controlflow.IfDerivative.class,
org.nd4j.linalg.api.ops.impl.controlflow.Select.class,
org.nd4j.linalg.api.ops.impl.controlflow.Where.class,
org.nd4j.linalg.api.ops.impl.controlflow.WhereNumpy.class,
org.nd4j.linalg.api.ops.impl.controlflow.While.class,
org.nd4j.linalg.api.ops.impl.controlflow.WhileDerivative.class,
org.nd4j.linalg.api.ops.impl.controlflow.compat.Enter.class,
org.nd4j.linalg.api.ops.impl.controlflow.compat.Exit.class,
org.nd4j.linalg.api.ops.impl.controlflow.compat.LoopCond.class,

View File

@ -1,413 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.imports.graphmapper;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.shade.protobuf.Message;
import org.nd4j.shade.protobuf.TextFormat;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.apache.commons.io.IOUtils;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.VariableType;
import org.nd4j.autodiff.samediff.internal.SameDiffOp;
import org.nd4j.autodiff.samediff.internal.Variable;
import org.nd4j.base.Preconditions;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import java.io.*;
import java.util.*;
/**
* Base implementation for importing a graph
*
* @param <GRAPH_TYPE> the type of graph
* @param <NODE_TYPE> the type of node
* @param <ATTR_TYPE> the attribute type
*/
@Slf4j
public abstract class BaseGraphMapper<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE, TENSOR_TYPE> implements GraphMapper<GRAPH_TYPE, NODE_TYPE, ATTR_TYPE, TENSOR_TYPE> {
@Override
public Op.Type opTypeForNode(NODE_TYPE nodeDef) {
DifferentialFunction opWithTensorflowName = getMappedOp(getOpType(nodeDef));
if (opWithTensorflowName == null)
throw new NoOpNameFoundException("No op found with name " + getOpType(nodeDef));
Op.Type type = opWithTensorflowName.opType();
return type;
}
@Override
public void mapProperties(DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappings) {
val mappings = propertyMappings.get(getOpType(node));
if (mappings == null || mappings.isEmpty()) {
return;
}
for (val entry : mappings.entrySet()) {
mapProperty(entry.getKey(), on, node, graph, sameDiff, propertyMappings);
}
}
/**
* @param inputStream
* @return
*/
@Override
public SameDiff importGraph(InputStream inputStream) {
return importGraph(inputStream, Collections.<String, OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>>emptyMap(), null);
}
@Override
public SameDiff importGraph(InputStream inputStream, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides,
OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter) {
GRAPH_TYPE def = readGraph(inputStream, opImportOverrides);
return importGraph(def, opImportOverrides, opFilter);
}
protected GRAPH_TYPE readGraph(InputStream inputStream, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides) {
byte[] bytes = null;
GRAPH_TYPE def = null;
try {
bytes = IOUtils.toByteArray(inputStream); //Buffers internally
def = parseGraphFrom(bytes);
} catch (IOException e) {
try (BufferedInputStream bis2 = new BufferedInputStream(new ByteArrayInputStream(bytes)); BufferedReader reader = new BufferedReader(new InputStreamReader(bis2))) {
Message.Builder builder = getNewGraphBuilder();
StringBuilder str = new StringBuilder();
String line = null;
while ((line = reader.readLine()) != null) {
str.append(line);//.append("\n");
}
TextFormat.getParser().merge(str.toString(), builder);
def = (GRAPH_TYPE) builder.build();
} catch (Exception e2) {
e2.printStackTrace();
}
}
return def;
}
/**
* @param graphFile
* @return
*/
@Override
public SameDiff importGraph(File graphFile) {
return importGraph(graphFile, Collections.<String, OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>>emptyMap(), null);
}
@Override
public SameDiff importGraph(File graphFile, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides,
OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter) {
GRAPH_TYPE def = null;
try (FileInputStream fis = new FileInputStream(graphFile)) {
return importGraph(fis, opImportOverrides, opFilter);
} catch (Exception e) {
throw new ND4JIllegalStateException("Error encountered loading graph file: " + graphFile.getAbsolutePath(), e);
}
}
@Override
public Map<String, NODE_TYPE> nameIndexForGraph(GRAPH_TYPE graph) {
List<NODE_TYPE> nodes = getNodeList(graph);
Map<String, NODE_TYPE> ret = new HashMap<>();
for (NODE_TYPE node : nodes) {
ret.put(getName(node), node);
}
return ret;
}
@Override
public Map<String, NODE_TYPE> nodesByName(GRAPH_TYPE graph) {
val nodeTypes = getNodeList(graph);
Map<String, NODE_TYPE> ret = new LinkedHashMap<>();
for (int i = 0; i < nodeTypes.size(); i++) {
ret.put(getName(nodeTypes.get(i)), nodeTypes.get(i));
}
return ret;
}
/**
* This method converts given TF
*
* @param tfGraph
* @return
*/
@Override
public SameDiff importGraph(GRAPH_TYPE tfGraph) {
return importGraph(tfGraph, Collections.<String, OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>>emptyMap(), null);
}
@Override
public SameDiff importGraph(GRAPH_TYPE tfGraph, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides,
OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter) {
SameDiff diff = SameDiff.create();
ImportState<GRAPH_TYPE, TENSOR_TYPE> importState = new ImportState<>();
importState.setSameDiff(diff);
importState.setGraph(tfGraph);
Map<String,TENSOR_TYPE> variablesForGraph = variablesForGraph(tfGraph);
importState.setVariables(variablesForGraph);
//Add each of the variables first - before importing ops
Map<String, Boolean> stringNodes = new HashMap<>(); //Key: name of string variable. Value: if it's a constant
for (Map.Entry<String, TENSOR_TYPE> entry : variablesForGraph.entrySet()) {
if (shouldSkip((NODE_TYPE) entry.getValue())) { //TODO only works for TF
//Skip some nodes, for example reduction indices (a lot of ND4J/SameDiff ops use int[] etc, not an INDArray/Variable)
continue;
}
//First: check if we're skipping the op entirely. If so: don't create the output variables for it.
NODE_TYPE node = (NODE_TYPE) entry.getValue(); //TODO this only works for TF
String opType = getOpType(node);
String opName = getName(node);
if(opFilter != null && opFilter.skipOp(node, importState.getSameDiff(), null, importState.getGraph() )){
log.info("Skipping variables for op: {} (name: {})", opType, opName);
continue;
}
//Similarly, if an OpImportOverride is defined, don't create the variables now, as these might be the wrong type
//For example, the OpImportOverride might replace the op with some placeholders
// If we simply created them now, we'd create the wrong type (Array not placeholder)
if(opImportOverrides != null && opImportOverrides.containsKey(opType)){
log.info("Skipping variables for op due to presence of OpImportOverride: {} (name: {})", opType, opName);
continue;
}
DataType dt = dataTypeForTensor(entry.getValue(), 0);
INDArray arr = getNDArrayFromTensor(entry.getKey(), entry.getValue(), tfGraph);
long[] shape = hasShape((NODE_TYPE) entry.getValue()) ? getShape((NODE_TYPE) entry.getValue()) : null; //TODO only works for TF
//Not all variables have datatypes available on import - we have to infer these at a later point
// so we'll leave datatypes as null and infer them once all variables/ops have been imported
if(dt == DataType.UNKNOWN)
dt = null;
if (isPlaceHolder(entry.getValue())) {
diff.placeHolder(entry.getKey(), dt, shape);
} else if (isConstant(entry.getValue())) {
Preconditions.checkNotNull(arr, "Array is null for placeholder variable %s", entry.getKey());
diff.constant(entry.getKey(), arr);
} else {
//Could be variable, or could be array type (i.e., output of op/"activations")
//TODO work out which!
SDVariable v;
if(shape == null || ArrayUtil.contains(shape, 0)){
//No shape, or 0 in shape -> probably not a variable...
v = diff.var(entry.getKey(), VariableType.ARRAY, null, dt, (long[])null);
} else {
v = diff.var(entry.getKey(), dt, shape);
}
if (arr != null)
diff.associateArrayWithVariable(arr, v);
}
// NODE_TYPE node = (NODE_TYPE) entry.getValue(); //TODO this only works for TF
List<String> controlDependencies = getControlDependencies(node);
if (controlDependencies != null) {
Variable v = diff.getVariables().get(entry.getKey());
v.setControlDeps(controlDependencies);
}
}
//Map ops
val tfNodesList = getNodeList(tfGraph);
for (NODE_TYPE node : tfNodesList) {
String opType = getOpType(node);
OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> importOverride = null;
if(opImportOverrides != null){
importOverride = opImportOverrides.get(opType);
}
if(opFilter != null && opFilter.skipOp(node, importState.getSameDiff(), null, null)){
String opName = getName(node);
log.info("Skipping op due to op filter: {}", opType, opName);
continue;
}
if (!opsToIgnore().contains(opType) || isOpIgnoreException(node)) {
mapNodeType(node, importState, importOverride, opFilter);
}
}
/*
At this point, we have a few remaining things to do:
1. Make sure all datatypes are set on all variables. TF doesn't have datatype info on all op outputs for some reason, so we have to infer it manually
2. Make sure all op output variables have been created
3. Make sure all SameDiffOp.outputsOfOp is set
4. Make sure all Variable.outputOfOp is set
5. Make sure all Variable.controlDepsForVar have been populated (reverse lookup of Variable.controlDeps)
*/
//Make sure Variable.outputOfOp is set
for(Variable v : diff.getVariables().values()){
if(v.getVariable().isPlaceHolder() || v.getVariable().isConstant())
continue;
//Expect variable names of output variables to be: opName, opName:1, opName:2, etc
String n = v.getName();
String opName = n;
if(v.getName().matches(".*:\\d+")){
//i.e., "something:2"
int idx = n.lastIndexOf(':');
opName = n.substring(0,idx);
}
if(diff.getOps().containsKey(opName)) {
//Variable is the output of an op
v.setOutputOfOp(opName);
//Also double check variable type...
if(v.getVariable().getVariableType() != VariableType.ARRAY)
v.getVariable().setVariableType(VariableType.ARRAY);
}
}
//Initialize any missing output variables
for (SameDiffOp op : diff.getOps().values()) {
DifferentialFunction df = op.getOp();
initOutputVariables(diff, df);
}
//Make sure all Variable.controlDepsForVar have been populated (reverse lookup of Variable.controlDeps)
//i.e., if control dependency x -> y exists, then:
// (a) x.controlDepsForVar should contain "y"
// (b) y.controlDeps should contain "x"
//Need to do this before output datatype calculation, as these control dep info is used in sessions
for(Map.Entry<String,Variable> e : diff.getVariables().entrySet()){
Variable v = e.getValue();
if(v.getControlDeps() != null){
for(String s : v.getControlDeps()){
Variable v2 = diff.getVariables().get(s);
if(v2.getControlDepsForVar() == null)
v2.setControlDepsForVar(new ArrayList<String>());
if(!v2.getControlDepsForVar().contains(e.getKey())){
//Control dep v2 -> v exists, so put v.name into v2.controlDepsForVar
v2.getControlDepsForVar().add(e.getKey());
}
}
}
}
//Same thing for op control dependencies...
for(Map.Entry<String,SameDiffOp> e : diff.getOps().entrySet()){
SameDiffOp op = e.getValue();
if(op.getControlDeps() != null){
for(String s : op.getControlDeps()){
//Control dependency varS -> op exists
Variable v = diff.getVariables().get(s);
if(v.getControlDepsForOp() == null)
v.setControlDepsForOp(new ArrayList<String>());
if(!v.getControlDepsForOp().contains(e.getKey()))
v.getControlDepsForOp().add(e.getKey());
}
}
}
//Infer variable datatypes to ensure all variables have datatypes...
boolean anyUnknown = false;
for(SDVariable v : diff.variables()){
if(v.dataType() == null)
anyUnknown = true;
}
if(anyUnknown){
Map<String,DataType> dataTypes = diff.calculateOutputDataTypes();
for(SDVariable v : diff.variables()){
if(v.dataType() == null){
v.setDataType(dataTypes.get(v.getVarName()));
}
}
}
//Validate the graph structure
validateGraphStructure(diff);
return diff;
}
protected void initOutputVariables(SameDiff sd, DifferentialFunction df) {
String[] outNames = sd.getOutputsForOp(df);
SDVariable[] outVars;
if (outNames == null) {
outVars = sd.generateOutputVariableForOp(df, df.getOwnName() != null ? df.getOwnName() : df.opName(), true);
outNames = new String[outVars.length];
for (int i = 0; i < outVars.length; i++) {
outNames[i] = outVars[i].getVarName();
}
sd.getOps().get(df.getOwnName()).setOutputsOfOp(Arrays.asList(outNames));
}
for (String s : outNames) {
sd.getVariables().get(s).setOutputOfOp(df.getOwnName());
}
}
@Override
public boolean validTensorDataType(TENSOR_TYPE tensorType) {
return dataTypeForTensor(tensorType, 0) != DataType.UNKNOWN;
}
public void validateGraphStructure(SameDiff sameDiff) {
//First: check placeholders. When SDVariables are added with null shapes, they can be interpreted as placeholders,
// but a null shape might simply mean the shape isn't yet available at the point during import when the variable is added
//Idea here: if a "placeholder" is the output of any function, it's not really a placeholder
for (SDVariable v : sameDiff.variables()) {
String name = v.getVarName();
if (sameDiff.isPlaceHolder(name)) {
String varOutputOf = sameDiff.getVariables().get(name).getOutputOfOp();
}
}
//Second: check that all op inputs actually exist in the graph
for (SameDiffOp op : sameDiff.getOps().values()) {
List<String> inputs = op.getInputsToOp();
if (inputs == null)
continue;
for (String s : inputs) {
if (sameDiff.getVariable(s) == null) {
throw new IllegalStateException("Import validation failed: op \"" + op.getName() + "\" of type " + op.getOp().getClass().getSimpleName()
+ " has input \"" + s + "\" that does not have a corresponding variable in the graph");
}
}
}
}
}
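For readers unfamiliar with the TF naming convention handled in the importGraph post-processing above, here is a small standalone sketch (hypothetical op and variable names, illustration only) of how an output variable name maps back to its op by stripping the trailing ":&lt;index&gt;":
//Illustration only: a TF op "MatMul" with two outputs produces variables "MatMul" and "MatMul:1";
//both resolve to the op "MatMul" once the ":<index>" suffix is removed.
String[] importedVarNames = {"MatMul", "MatMul:1"};
for (String n : importedVarNames) {
    String opName = n.matches(".*:\\d+") ? n.substring(0, n.lastIndexOf(':')) : n;
    System.out.println(n + " -> output of op " + opName);
}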

View File

@ -1,429 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.imports.graphmapper;
import org.nd4j.shade.protobuf.Message;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
* Map graph proto types to {@link SameDiff} instances
*
* @param <GRAPH_TYPE> the proto type for the graph
* @param <NODE_TYPE> the proto type for the node
* @param <ATTR_TYPE> the proto type for the attribute
* @param <TENSOR_TYPE> the proto type for the tensor
* @author Adam Gibson
*/
public interface GraphMapper<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE,TENSOR_TYPE> {
/**
* Import a graph as SameDiff from the given input stream
* @param graphFile Input stream pointing to the graph file to import
* @return Imported graph
*/
SameDiff importGraph(InputStream graphFile);
SameDiff importGraph(InputStream graphFile, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides,
OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter);
/**
* Import a graph as SameDiff from the given file
* @param graphFile Graph file to import
* @return Imported graph
* @see #importGraph(File, Map)
*/
SameDiff importGraph(File graphFile);
/**
* Import a graph as SameDiff from the given file, with optional op import overrides.<br>
* The {@link OpImportOverride} instances allow the operation import to be overridden - useful for importing ops
* that have not been mapped for import in SameDiff yet, and also for non-standard/user-defined functions.
*
* @param graphFile Graph file to import
* @param opImportOverrides May be null. If non-null: used to import the specified operations. Key is the name of the
* operation to import, value is the object used to import it
* @return Imported graph
*/
SameDiff importGraph(File graphFile, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides,
OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter);
/**
* This method converts given graph type (in its native format) to SameDiff
* @param graph Graph to import
* @return Imported graph
*/
SameDiff importGraph(GRAPH_TYPE graph);
/**
* This method converts given graph type (in its native format) to SameDiff<br>
* The {@link OpImportOverride} instances allow the operation import to be overridden - useful for importing ops
* that have not been mapped for import in SameDiff yet, and also for non-standard/user-defined functions.
* @param graph Graph to import
* @return Imported graph
*/
SameDiff importGraph(GRAPH_TYPE graph, Map<String,? extends OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE>> opImportOverrides,
OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter);
/**
* Returns true if this node is a special case
* (maybe because of name or other scenarios)
* that should override {@link #opsToIgnore()}
* in certain circumstances
* @param node the node to check
* @return true if this node is an exception false otherwise
*/
boolean isOpIgnoreException(NODE_TYPE node);
/**
* Get the nodes sorted by name
* from a given graph
* @param graph the graph to get the nodes for
* @return the map of the nodes by name
* for a given graph
*/
Map<String,NODE_TYPE> nodesByName(GRAPH_TYPE graph);
/**
* Get the target mapping key (usually based on the node name)
* for the given function
* @param function the function
* @param node the node to derive the target mapping from
* @return
*/
String getTargetMappingForOp(DifferentialFunction function, NODE_TYPE node);
/**
*
* @param on
* @param node
* @param graph
* @param sameDiff
* @param propertyMappings
*/
void mapProperties(DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappings);
/**
*
* @param name
* @param on
* @param node
* @param graph
* @param sameDiff
* @param propertyMappingsForFunction
*/
void mapProperty(String name, DifferentialFunction on, NODE_TYPE node, GRAPH_TYPE graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction);
/**
* Get the node from the graph
* @param graph the graph to get the node from
* @param name the name of the node to get from the graph
* @return
*/
NODE_TYPE getNodeWithNameFromGraph(GRAPH_TYPE graph,String name);
/**
* Returns true if the given node is a place holder
* @param node the node to check
* @return true if the node is a placeholder, false otherwise
*/
boolean isPlaceHolderNode(TENSOR_TYPE node);
/**
* Get the list of control dependencies for the current node (or null if none exist)
*
* @param node Node to get the control dependencies (if any) for
* @return
*/
List<String> getControlDependencies(NODE_TYPE node);
/**
* Dump a binary proto file representation as a
* plain string into the target text file
* @param inputFile
* @param outputFile
*/
void dumpBinaryProtoAsText(File inputFile,File outputFile);
/**
* Dump a binary proto file representation as a
* plain string into the target text file
* @param inputFile
* @param outputFile
*/
void dumpBinaryProtoAsText(InputStream inputFile,File outputFile);
/**
* Get the mapped op name
* for a given op
* relative to the type of node being mapped.
* The input name should be based on a tensorflow
* type or onnx type, not the nd4j name
* @param name the tensorflow or onnx name
* @return the function based on the values in
* {@link org.nd4j.imports.converters.DifferentialFunctionClassHolder}
*/
DifferentialFunction getMappedOp(String name);
/**
* Get the variables for the given graph
* @param graphType the graph to get the variables for
* @return a map of variable name to tensor
*/
Map<String,TENSOR_TYPE> variablesForGraph(GRAPH_TYPE graphType);
/**
*
* @param name
* @param node
* @return
*/
String translateToSameDiffName(String name, NODE_TYPE node);
/**
*
* @param graph
* @return
*/
Map<String,NODE_TYPE> nameIndexForGraph(GRAPH_TYPE graph);
/**
* Returns an op type for the given input node
* @param nodeType the node to use
* @return the optype for the given node
*/
Op.Type opTypeForNode(NODE_TYPE nodeType);
/**
* Returns a graph builder for initial definition and parsing.
* @return
*/
Message.Builder getNewGraphBuilder();
/**
* Parse a graph from an input stream
* @param inputStream the input stream to load from
* @return
*/
GRAPH_TYPE parseGraphFrom(byte[] inputStream) throws IOException;
/**
* Parse a graph from an input stream
* @param inputStream the input stream to load from
* @return
*/
GRAPH_TYPE parseGraphFrom(InputStream inputStream) throws IOException;
/**
* Map a node into the import state covering the {@link SameDiff} instance
* @param tfNode the node to map
* @param importState the current import state
* @param opFilter Optional filter for skipping operations
*/
void mapNodeType(NODE_TYPE tfNode, ImportState<GRAPH_TYPE,TENSOR_TYPE> importState,
OpImportOverride<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opImportOverride,
OpImportFilter<GRAPH_TYPE,NODE_TYPE,ATTR_TYPE> opFilter);
/**
*
* @param tensorType
* @param outputNum
* @return
*/
DataType dataTypeForTensor(TENSOR_TYPE tensorType, int outputNum);
boolean isStringType(TENSOR_TYPE tensor);
/**
*
* @param nodeType
* @param key
* @return
*/
String getAttrValueFromNode(NODE_TYPE nodeType,String key);
/**
*
* @param attrType
* @return
*/
long[] getShapeFromAttribute(ATTR_TYPE attrType);
/**
* Returns true if the given node is a place holder type
* (think: a yet-to-be-determined shape)
* @param nodeType
* @return
*/
boolean isPlaceHolder(TENSOR_TYPE nodeType);
/**
* Returns true if the given node is a constant
* @param nodeType
* @return
*/
boolean isConstant(TENSOR_TYPE nodeType);
/**
*
*
* @param tensorName
* @param tensorType
* @param graph
* @return
*/
INDArray getNDArrayFromTensor(String tensorName, TENSOR_TYPE tensorType, GRAPH_TYPE graph);
/**
* Get the shape for the given tensor type
* @param tensorType
* @return
*/
long[] getShapeFromTensor(TENSOR_TYPE tensorType);
/**
* Ops to ignore for mapping
* @return
*/
Set<String> opsToIgnore();
/**
* Get the input node for the given node
* @param node the node
* @param index the index
* @return
*/
String getInputFromNode(NODE_TYPE node, int index);
/**
* Get the number of inputs for a node.
* @param nodeType the node to get the number of inputs for
* @return
*/
int numInputsFor(NODE_TYPE nodeType);
/**
* Whether the data type for the tensor is valid
* for creating an {@link INDArray}
* @param tensorType the tensor proto to test
* @return
*/
boolean validTensorDataType(TENSOR_TYPE tensorType);
/**
* Get the shape of the attribute value
* @param attr the attribute value
* @return the shape of the attribute if any or null
*/
long[] getShapeFromAttr(ATTR_TYPE attr);
/**
* Get the attribute
* map for given node
* @param nodeType the node
* @return the attribute map for the attribute
*/
Map<String,ATTR_TYPE> getAttrMap(NODE_TYPE nodeType);
/**
* Get the name of the node
* @param nodeType the node
* to get the name for
* @return
*/
String getName(NODE_TYPE nodeType);
/**
*
* @param nodeType
* @return
*/
boolean alreadySeen(NODE_TYPE nodeType);
/**
*
* @param nodeType
* @return
*/
boolean isVariableNode(NODE_TYPE nodeType);
/**
*
*
* @param opType
* @return
*/
boolean shouldSkip(NODE_TYPE opType);
/**
*
* @param nodeType
* @return
*/
boolean hasShape(NODE_TYPE nodeType);
/**
*
* @param nodeType
* @return
*/
long[] getShape(NODE_TYPE nodeType);
/**
*
* @param nodeType
* @param graph
* @return
*/
INDArray getArrayFrom(NODE_TYPE nodeType, GRAPH_TYPE graph);
String getOpType(NODE_TYPE nodeType);
/**
*
* @param graphType
* @return
*/
List<NODE_TYPE> getNodeList(GRAPH_TYPE graphType);
}

View File

@ -1,31 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.imports.graphmapper;
import lombok.Data;
import org.nd4j.autodiff.samediff.SameDiff;
import java.util.Map;
@Data
public class ImportState<GRAPH_TYPE,TENSOR_TYPE> {
private SameDiff sameDiff;
private GRAPH_TYPE graph;
private Map<String,TENSOR_TYPE> variables;
}

View File

@ -1,652 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.imports.graphmapper.onnx;
import org.nd4j.shade.protobuf.ByteString;
import org.nd4j.shade.protobuf.Message;
import org.nd4j.shade.guava.primitives.Floats;
import org.nd4j.shade.guava.primitives.Ints;
import org.nd4j.shade.guava.primitives.Longs;
import lombok.val;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.descriptors.properties.AttributeAdapter;
import org.nd4j.imports.descriptors.properties.PropertyMapping;
import org.nd4j.imports.graphmapper.BaseGraphMapper;
import org.nd4j.imports.graphmapper.ImportState;
import org.nd4j.imports.graphmapper.OpImportFilter;
import org.nd4j.imports.graphmapper.OpImportOverride;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.*;
/**
* A mapper for onnx graphs to
* {@link org.nd4j.autodiff.samediff.SameDiff} instances.
*
* @author Adam Gibson
*/
public class OnnxGraphMapper extends BaseGraphMapper<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto, onnx.Onnx.TypeProto.Tensor> {
private static OnnxGraphMapper INSTANCE = new OnnxGraphMapper();
public static OnnxGraphMapper getInstance() {
return INSTANCE;
}
@Override
public void dumpBinaryProtoAsText(InputStream inputFile, File outputFile) {
try {
Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(inputFile);
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) {
bufferedWriter.write(node.toString() + "\n");
}
bufferedWriter.flush();
bufferedWriter.close();
} catch (IOException e) {
e.printStackTrace();
}
}
/**
* Init a function's attributes
* @param mappedTfName the onnx name to pick (sometimes ops have multiple names)
* @param on the function to map
* @param attributesForNode the attributes for the node
* @param node
* @param graph
*/
public void initFunctionFromProperties(String mappedTfName, DifferentialFunction on, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.NodeProto node, Onnx.GraphProto graph) {
val properties = on.mappingsForFunction();
val tfProperties = properties.get(mappedTfName);
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
val attributeAdapters = on.attributeAdaptersForFunction();
for(val entry : tfProperties.entrySet()) {
val tfAttrName = entry.getValue().getTfAttrName();
val currentField = fields.get(entry.getKey());
AttributeAdapter adapter = null;
if(tfAttrName != null) {
if(currentField == null) {
continue;
}
if(attributeAdapters != null && !attributeAdapters.isEmpty()) {
val mappers = attributeAdapters.get(on.tensorflowName());
val adapterFor = mappers.get(entry.getKey());
adapter = adapterFor;
}
if(attributesForNode.containsKey(tfAttrName)) {
val attr = attributesForNode.get(tfAttrName);
switch (attr.getType()) {
case STRING:
val setString = attr.getS().toStringUtf8();
if(adapter != null) {
adapter.mapAttributeFor(setString,currentField,on);
}
else
on.setValueFor(currentField,setString);
break;
case INT:
val setInt = (int) attr.getI();
if(adapter != null) {
adapter.mapAttributeFor(setInt,currentField,on);
}
else
on.setValueFor(currentField,setInt);
break;
case INTS:
val setList = attr.getIntsList();
if(!setList.isEmpty()) {
val intList = Ints.toArray(setList);
if(adapter != null) {
adapter.mapAttributeFor(intList,currentField,on);
}
else
on.setValueFor(currentField,intList);
}
break;
case FLOATS:
val floatsList = attr.getFloatsList();
if(!floatsList.isEmpty()) {
val floats = Floats.toArray(floatsList);
if(adapter != null) {
adapter.mapAttributeFor(floats,currentField,on);
}
else
on.setValueFor(currentField,floats);
break;
}
break;
case TENSOR:
val tensorToGet = mapTensorProto(attr.getT());
if(adapter != null) {
adapter.mapAttributeFor(tensorToGet,currentField,on);
}
else
on.setValueFor(currentField,tensorToGet);
break;
}
}
}
}
}
@Override
public boolean isOpIgnoreException(Onnx.NodeProto node) {
return false;
}
@Override
public String getTargetMappingForOp(DifferentialFunction function, Onnx.NodeProto node) {
return function.opName();
}
@Override
public void mapProperty(String name, DifferentialFunction on, Onnx.NodeProto node, Onnx.GraphProto graph, SameDiff sameDiff, Map<String, Map<String, PropertyMapping>> propertyMappingsForFunction) {
val mapping = propertyMappingsForFunction.get(name).get(getTargetMappingForOp(on, node));
val fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(on);
/**
* Map ints and the like. Need to figure out how attribute mapping should work.
*
*
*/
val propsForFunction = on.propertiesForFunction();
if(mapping.getTfAttrName() == null) {
int tfMappingIdx = mapping.getTfInputPosition();
if(tfMappingIdx < 0)
tfMappingIdx += node.getInputCount();
val input = node.getInput(tfMappingIdx);
val inputNode = getInstance().getNodeWithNameFromGraph(graph,input);
INDArray arr = sameDiff.getArrForVarName(input);
val field = fields.get(mapping.getPropertyNames()[0]);
val type = field.getType();
if(type.equals(int[].class)) {
try {
field.set(arr.data().asInt(),on);
} catch (IllegalAccessException e) {
e.printStackTrace();
}
}
else if(type.equals(int.class) || type.equals(long.class) || type.equals(Long.class) || type.equals(Integer.class)) {
try {
field.set(arr.getInt(0),on);
} catch (IllegalAccessException e) {
e.printStackTrace();
}
}
else if(type.equals(float.class) || type.equals(double.class) || type.equals(Float.class) || type.equals(Double.class)) {
try {
field.set(arr.getDouble(0),on);
} catch (IllegalAccessException e) {
e.printStackTrace();
}
}
/**
* Figure out whether it's an int array
* or a double array, or maybe a scalar.
*/
}
else {
val tfMappingAttrName = mapping.getOnnxAttrName();
val attr = getAttrMap(node).get(tfMappingAttrName);
val type = attr.getType();
val field = fields.get(mapping.getPropertyNames()[0]);
Object valueToSet = null;
switch(type) {
case INT:
valueToSet = attr.getI();
break;
case FLOAT:
valueToSet = attr.getF();
break;
case STRING:
valueToSet = attr.getF();
break;
}
try {
field.set(valueToSet,on);
} catch (IllegalAccessException e) {
e.printStackTrace();
}
}
}
@Override
public Onnx.NodeProto getNodeWithNameFromGraph(Onnx.GraphProto graph, String name) {
for(int i = 0; i < graph.getNodeCount(); i++) {
val node = graph.getNode(i);
if(node.getName().equals(name))
return node;
}
return null;
}
@Override
public boolean isPlaceHolderNode(Onnx.TypeProto.Tensor node) {
return false;
}
@Override
public List<String> getControlDependencies(Onnx.NodeProto node) {
throw new UnsupportedOperationException("Not yet implemented");
}
@Override
public void dumpBinaryProtoAsText(File inputFile, File outputFile) {
try {
Onnx.ModelProto graphDef = Onnx.ModelProto.parseFrom(new BufferedInputStream(new FileInputStream(inputFile)));
BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(outputFile,true));
for(Onnx.NodeProto node : graphDef.getGraph().getNodeList()) {
bufferedWriter.write(node.toString());
}
bufferedWriter.flush();
bufferedWriter.close();
} catch (IOException e) {
e.printStackTrace();
}
}
/**
*
* @param name the tensorflow or onnx name
* @return
*/
@Override
public DifferentialFunction getMappedOp(String name) {
return DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(name);
}
@Override
public Map<String,onnx.Onnx.TypeProto.Tensor> variablesForGraph(Onnx.GraphProto graphProto) {
/**
* Need to figure out why
* gpu_0/conv1_1 isn't present in VGG
*/
Map<String,onnx.Onnx.TypeProto.Tensor> ret = new HashMap<>();
for(int i = 0; i < graphProto.getInputCount(); i++) {
ret.put(graphProto.getInput(i).getName(),graphProto.getInput(i).getType().getTensorType());
}
for(int i = 0; i < graphProto.getOutputCount(); i++) {
ret.put(graphProto.getOutput(i).getName(),graphProto.getOutput(i).getType().getTensorType());
}
for(int i = 0; i < graphProto.getNodeCount(); i++) {
val node = graphProto.getNode(i);
val name = node.getName().isEmpty() ? String.valueOf(i) : node.getName();
//add -1 as place holder value representing the shape needs to be filled in
if(!ret.containsKey(name)) {
addDummyTensor(name,ret);
}
for(int j = 0; j < node.getInputCount(); j++) {
if(!ret.containsKey(node.getInput(j))) {
addDummyTensor(node.getInput(j),ret);
}
}
for(int j = 0; j < node.getOutputCount(); j++) {
if(!ret.containsKey(node.getOutput(j))) {
addDummyTensor(node.getOutput(j),ret);
}
}
}
return ret;
}
@Override
public String translateToSameDiffName(String name, Onnx.NodeProto node) {
return null;
}
protected void addDummyTensor(String name, Map<String, Onnx.TypeProto.Tensor> to) {
Onnx.TensorShapeProto.Dimension dim = Onnx.TensorShapeProto.Dimension.
newBuilder()
.setDimValue(-1)
.build();
Onnx.TypeProto.Tensor typeProto = Onnx.TypeProto.Tensor.newBuilder()
.setShape(
Onnx.TensorShapeProto.newBuilder()
.addDim(dim)
.addDim(dim).build())
.build();
to.put(name,typeProto);
}
@Override
public Message.Builder getNewGraphBuilder() {
return Onnx.GraphProto.newBuilder();
}
@Override
public Onnx.GraphProto parseGraphFrom(byte[] inputStream) throws IOException {
return Onnx.ModelProto.parseFrom(inputStream).getGraph();
}
@Override
public Onnx.GraphProto parseGraphFrom(InputStream inputStream) throws IOException {
return Onnx.ModelProto.parseFrom(inputStream).getGraph();
}
@Override
public void mapNodeType(Onnx.NodeProto tfNode, ImportState<Onnx.GraphProto, Onnx.TypeProto.Tensor> importState,
OpImportOverride<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto> opImportOverride,
OpImportFilter<Onnx.GraphProto, Onnx.NodeProto, Onnx.AttributeProto> opFilter) {
val differentialFunction = DifferentialFunctionClassHolder.getInstance().getOpWithOnnxName(tfNode.getOpType());
if(differentialFunction == null) {
throw new NoOpNameFoundException("No op name found " + tfNode.getOpType());
}
val diff = importState.getSameDiff();
val idx = importState.getGraph().getNodeList().indexOf(tfNode);
val name = !tfNode.getName().isEmpty() ? tfNode.getName() : String.valueOf(idx);
try {
val newInstance = differentialFunction.getClass().newInstance();
val args = new SDVariable[tfNode.getInputCount()];
newInstance.setSameDiff(importState.getSameDiff());
newInstance.initFromOnnx(tfNode,diff,getAttrMap(tfNode),importState.getGraph());
importState.getSameDiff().putOpForId(newInstance.getOwnName(),newInstance);
//ensure we can track node name to function instance later.
diff.setBaseNameForFunctionInstanceId(tfNode.getName(),newInstance);
//diff.addVarNameForImport(tfNode.getName());
}
catch (Exception e) {
e.printStackTrace();
}
}
@Override
public DataType dataTypeForTensor(Onnx.TypeProto.Tensor tensorProto, int outputNum) {
return nd4jTypeFromOnnxType(tensorProto.getElemType());
}
@Override
public boolean isStringType(Onnx.TypeProto.Tensor tensor) {
return tensor.getElemType() == Onnx.TensorProto.DataType.STRING;
}
/**
* Convert an onnx type to the proper nd4j type
* @param dataType the data type to convert
* @return the nd4j type for the onnx type
*/
public DataType nd4jTypeFromOnnxType(Onnx.TensorProto.DataType dataType) {
switch (dataType) {
case DOUBLE: return DataType.DOUBLE;
case FLOAT: return DataType.FLOAT;
case FLOAT16: return DataType.HALF;
case INT32:
case INT64: return DataType.INT;
default: return DataType.UNKNOWN;
}
}
@Override
public String getAttrValueFromNode(Onnx.NodeProto nodeProto, String key) {
for(Onnx.AttributeProto attributeProto : nodeProto.getAttributeList()) {
if(attributeProto.getName().equals(key)) {
return attributeProto.getS().toString();
}
}
throw new ND4JIllegalStateException("No key found for " + key);
}
@Override
public long[] getShapeFromAttribute(Onnx.AttributeProto attributeProto) {
return Longs.toArray(attributeProto.getT().getDimsList());
}
@Override
public boolean isPlaceHolder(Onnx.TypeProto.Tensor nodeType) {
return false;
}
@Override
public boolean isConstant(Onnx.TypeProto.Tensor nodeType) {
return false;
}
@Override
public INDArray getNDArrayFromTensor(String tensorName, Onnx.TypeProto.Tensor tensorProto, Onnx.GraphProto graph) {
DataType type = dataTypeForTensor(tensorProto, 0);
if(!tensorProto.isInitialized()) {
throw new ND4JIllegalStateException("Unable to retrieve ndarray. Tensor was not initialized");
}
Onnx.TensorProto tensor = null;
for(int i = 0; i < graph.getInitializerCount(); i++) {
val initializer = graph.getInitializer(i);
if(initializer.getName().equals(tensorName)) {
tensor = initializer;
break;
}
}
if(tensor == null)
return null;
ByteString bytes = tensor.getRawData();
ByteBuffer byteBuffer = bytes.asReadOnlyByteBuffer().order(ByteOrder.nativeOrder());
ByteBuffer directAlloc = ByteBuffer.allocateDirect(byteBuffer.capacity()).order(ByteOrder.nativeOrder());
directAlloc.put(byteBuffer);
directAlloc.rewind();
long[] shape = getShapeFromTensor(tensorProto);
DataBuffer buffer = Nd4j.createBuffer(directAlloc,type, ArrayUtil.prod(shape));
INDArray arr = Nd4j.create(buffer).reshape(shape);
return arr;
}
public INDArray mapTensorProto(Onnx.TensorProto tensor) {
if(tensor == null)
return null;
DataType type = nd4jTypeFromOnnxType(tensor.getDataType());
ByteString bytes = tensor.getRawData();
ByteBuffer byteBuffer = bytes.asReadOnlyByteBuffer().order(ByteOrder.nativeOrder());
ByteBuffer directAlloc = ByteBuffer.allocateDirect(byteBuffer.capacity()).order(ByteOrder.nativeOrder());
directAlloc.put(byteBuffer);
directAlloc.rewind();
long[] shape = getShapeFromTensor(tensor);
DataBuffer buffer = Nd4j.createBuffer(directAlloc,type, ArrayUtil.prod(shape));
INDArray arr = Nd4j.create(buffer).reshape(shape);
return arr;
}
@Override
public long[] getShapeFromTensor(onnx.Onnx.TypeProto.Tensor tensorProto) {
val ret = new long[Math.max(2,tensorProto.getShape().getDimCount())];
int dimCount = tensorProto.getShape().getDimCount();
if(dimCount >= 2)
for(int i = 0; i < ret.length; i++) {
ret[i] = (int) tensorProto.getShape().getDim(i).getDimValue();
}
else {
ret[0] = 1;
for(int i = 1; i < ret.length; i++) {
ret[i] = (int) tensorProto.getShape().getDim(i - 1).getDimValue();
}
}
return ret;
}
/**
* Get the shape from a tensor proto.
* Note that this is different from {@link #getShapeFromTensor(Onnx.TensorProto)}
* @param tensorProto the tensor to get the shape from
* @return
*/
public long[] getShapeFromTensor(Onnx.TensorProto tensorProto) {
val ret = new long[Math.max(2,tensorProto.getDimsCount())];
int dimCount = tensorProto.getDimsCount();
if(dimCount >= 2)
for(int i = 0; i < ret.length; i++) {
ret[i] = (int) tensorProto.getDims(i);
}
else {
ret[0] = 1;
for(int i = 1; i < ret.length; i++) {
ret[i] = (int) tensorProto.getDims(i - 1);
}
}
return ret;
}
@Override
public Set<String> opsToIgnore() {
return Collections.emptySet();
}
@Override
public String getInputFromNode(Onnx.NodeProto node, int index) {
return node.getInput(index);
}
@Override
public int numInputsFor(Onnx.NodeProto nodeProto) {
return nodeProto.getInputCount();
}
@Override
public long[] getShapeFromAttr(Onnx.AttributeProto attr) {
return Longs.toArray(attr.getT().getDimsList());
}
@Override
public Map<String, Onnx.AttributeProto> getAttrMap(Onnx.NodeProto nodeProto) {
Map<String,Onnx.AttributeProto> proto = new HashMap<>();
for(int i = 0; i < nodeProto.getAttributeCount(); i++) {
Onnx.AttributeProto attributeProto = nodeProto.getAttribute(i);
proto.put(attributeProto.getName(),attributeProto);
}
return proto;
}
@Override
public String getName(Onnx.NodeProto nodeProto) {
return nodeProto.getName();
}
@Override
public boolean alreadySeen(Onnx.NodeProto nodeProto) {
return false;
}
@Override
public boolean isVariableNode(Onnx.NodeProto nodeProto) {
return nodeProto.getOpType().contains("Var");
}
@Override
public boolean shouldSkip(Onnx.NodeProto opType) {
return false;
}
@Override
public boolean hasShape(Onnx.NodeProto nodeProto) {
return false;
}
@Override
public long[] getShape(Onnx.NodeProto nodeProto) {
return null;
}
@Override
public INDArray getArrayFrom(Onnx.NodeProto nodeProto, Onnx.GraphProto graph) {
return null;
}
@Override
public String getOpType(Onnx.NodeProto nodeProto) {
return nodeProto.getOpType();
}
@Override
public List<Onnx.NodeProto> getNodeList(Onnx.GraphProto graphProto) {
return graphProto.getNodeList();
}
}

View File

@ -226,22 +226,24 @@ public class TensorFlowImportValidator {
}
public static TFImportStatus checkModelForImport(String path, InputStream is, boolean exceptionOnRead) throws IOException {
- TFGraphMapper m = TFGraphMapper.getInstance();
try {
int opCount = 0;
Set<String> opNames = new HashSet<>();
try(InputStream bis = new BufferedInputStream(is)) {
- GraphDef graphDef = m.parseGraphFrom(bis);
- List<NodeDef> nodes = m.getNodeList(graphDef);
+ GraphDef graphDef = GraphDef.parseFrom(bis);
+ List<NodeDef> nodes = new ArrayList<>(graphDef.getNodeCount());
+ for( int i=0; i<graphDef.getNodeCount(); i++ ){
+ nodes.add(graphDef.getNode(i));
+ }
if(nodes.isEmpty()){
throw new IllegalStateException("Error loading model for import - loaded graph def has no nodes (empty/corrupt file?): " + path);
}
for (NodeDef nd : nodes) {
- if (m.isVariableNode(nd) || m.isPlaceHolderNode(nd))
+ if (TFGraphMapper.isVariableNode(nd) || TFGraphMapper.isPlaceHolder(nd))
continue;
String op = nd.getOp();
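A minimal usage sketch of the validator after this change (hypothetical model path; only the checkModelForImport overload shown above is assumed):
try (InputStream is = new BufferedInputStream(new FileInputStream("/path/to/frozen_model.pb"))) {
    //Returns a TFImportStatus summarising which ops in the frozen graph can currently be imported
    TFImportStatus status = TensorFlowImportValidator.checkModelForImport("/path/to/frozen_model.pb", is, false);
    System.out.println(status);
}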

View File

@ -86,6 +86,7 @@ import java.io.*;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.util.*;
+ import java.util.concurrent.atomic.AtomicLong;
import static org.nd4j.linalg.factory.Nd4j.*;
@ -124,6 +125,9 @@ public abstract class BaseNDArray implements INDArray, Iterable {
protected transient JvmShapeInfo jvmShapeInfo;
+ private static final AtomicLong arrayCounter = new AtomicLong(0);
+ protected transient final long arrayId = arrayCounter.getAndIncrement();
//Precalculate these arrays (like [3,2,1,0], [2,1,0], [1,0], [0] etc) for use in TAD, to avoid creating same int[]s over and over
private static final int[][] tadFinalPermuteDimensions;
@ -139,7 +143,6 @@ public abstract class BaseNDArray implements INDArray, Iterable {
}
public BaseNDArray() {
}
@Override
@ -4916,6 +4919,8 @@ public abstract class BaseNDArray implements INDArray, Iterable {
@Override
public String toString(@NonNull NDArrayStrings options){
+ if(wasClosed())
+ return "<Closed NDArray, id=" + getId() + ", dtype=" + dataType() + ", shape=" + Arrays.toString(shape()) + ">";
if (!isCompressed() && !preventUnpack)
return options.format(this);
else if (isCompressed() && compressDebug)
@ -5600,4 +5605,9 @@ public abstract class BaseNDArray implements INDArray, Iterable {
return false;
}
+ @Override
+ public long getId(){
+ return arrayId;
+ }
}
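A minimal sketch of the new closed-array behaviour (assuming an array released via close()): toString() now reports the closed state instead of throwing:
INDArray arr = Nd4j.createFromArray(1.0f, 2.0f, 3.0f);
System.out.println(arr);   // normal formatted contents
arr.close();               // releases the underlying buffer
System.out.println(arr);   // e.g. <Closed NDArray, id=..., dtype=FLOAT, shape=[3]> rather than an exception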

View File

@ -2814,4 +2814,10 @@ public interface INDArray extends Serializable, AutoCloseable {
* @see org.nd4j.linalg.api.ndarray.BaseNDArray#toString(long, boolean, int)
*/
String toStringFull();
+ /**
+ * A unique ID for the INDArray object instance. Does not account for views.
+ * @return INDArray unique ID
+ */
+ long getId();
}
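A small sketch of the intended getId() contract, based on the per-instance counter added in BaseNDArray above (illustration only, not a formal spec):
INDArray a = Nd4j.create(DataType.FLOAT, 2, 2);
INDArray b = a.dup();
System.out.println(a.getId() == b.getId());   // false: each INDArray instance gets its own id, even if contents match
System.out.println(a.getId() == a.getId());   // true: the id is stable for the lifetime of the instance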

View File

@ -24,6 +24,7 @@ import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
+ import org.nd4j.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
@ -200,48 +201,17 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
@Override
public void setX(INDArray x) {
- if (x == null) {
- if (args() != null && args().length >= 1) {
- SDVariable firstArg = args()[0];
- if (firstArg.getArr() != null)
- this.x = firstArg.getArr();
- } else
- throw new ND4JIllegalStateException("Unable to set null array for x. Also unable to infer from differential function arguments");
- } else
- this.x = x;
+ this.x = x;
}
@Override
public void setZ(INDArray z) {
- if (z == null) {
- SDVariable getResult = sameDiff.getVariable(zVertexId);
- if (getResult != null) {
- if (getResult.getArr() != null)
- this.z = getResult.getArr();
- else if(sameDiff.getShapeForVarName(getResult.getVarName()) != null) {
- val shape = sameDiff.getShapeForVarName(getResult.getVarName());
- sameDiff.setArrayForVariable(getResult.getVarName(),getResult.getWeightInitScheme().create(getResult.dataType(), shape));
- }
- else
- throw new ND4JIllegalStateException("Unable to set null array for z. Also unable to infer from differential function arguments");
- } else
- throw new ND4JIllegalStateException("Unable to set null array for z. Also unable to infer from differential function arguments");
- } else
- this.z = z;
+ this.z = z;
}
@Override
public void setY(INDArray y) {
- if (y == null) {
- if (args() != null && args().length > 1) {
- SDVariable firstArg = args()[1];
- if (firstArg.getArr() != null)
- this.y = firstArg.getArr();
- } else
- throw new ND4JIllegalStateException("Unable to set null array for y. Also unable to infer from differential function arguments");
- } else
- this.y = y;
+ this.y = y;
}
@Override
@ -265,6 +235,12 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
return z;
}
+ @Override
+ public INDArray getInputArgument(int index){
+ Preconditions.checkState(index >= 0 && index < 2, "Input argument index must be 0 or 1, got %s", index);
+ return index == 0 ? x : y;
+ }
@Override
public SDVariable[] outputVariables(String baseName) {
if(zVertexId == null) {
@ -403,4 +379,11 @@ public abstract class BaseOp extends DifferentialFunction implements Op {
//Always 1 for legacy/base ops
return 1;
}
+ @Override
+ public void clearArrays(){
+ x = null;
+ y = null;
+ z = null;
+ }
}

View File

@ -16,7 +16,6 @@
package org.nd4j.linalg.api.ops;
- import org.nd4j.shade.guava.primitives.Ints;
import lombok.Getter;
import lombok.Setter;
import lombok.extern.slf4j.Slf4j;
@ -24,21 +23,14 @@ import lombok.val;
import onnx.Onnx;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
- import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
- import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
- import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.api.shape.Shape;
- import org.nd4j.linalg.exception.ND4JIllegalStateException;
- import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
- import java.util.ArrayList;
- import java.util.Collections;
import java.util.List;
import java.util.Map;
@ -71,10 +63,6 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
this.keepDims = keepDims;
this.xVertexId = i_v.getVarName();
sameDiff.addArgsFor(new String[]{xVertexId},this);
- if(Shape.isPlaceholderShape(i_v.getShape())) {
- sameDiff.addPropertyToResolve(this,i_v.getVarName());
- }
} else {
throw new IllegalArgumentException("Input not null variable.");
}
@ -219,14 +207,7 @@ public abstract class BaseReduceOp extends BaseOp implements ReduceOp {
@Override
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
- if (!attributesForNode.containsKey("axes")) {
- this.dimensions = new int[] { Integer.MAX_VALUE };
- }
- else {
- val map = OnnxGraphMapper.getInstance().getAttrMap(node);
- val dims = Ints.toArray(map.get("axes").getIntsList());
- this.dimensions = dims;
- }
}
@Override

View File

@ -119,4 +119,9 @@ public interface CustomOp {
* otherwise throws an {@link org.nd4j.linalg.exception.ND4JIllegalStateException}
*/
void assertValidForExecution();
+ /**
+ * Clear the input and output INDArrays, if any are set
+ */
+ void clearArrays();
}

View File

@ -263,7 +263,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
@Override
public INDArray[] outputArguments() {
if (!outputArguments.isEmpty()) {
- return outputArguments.toArray(new INDArray[outputArguments.size()]);
+ return outputArguments.toArray(new INDArray[0]);
}
return new INDArray[0];
}
@ -271,7 +271,7 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
@Override
public INDArray[] inputArguments() {
if (!inputArguments.isEmpty())
- return inputArguments.toArray(new INDArray[inputArguments.size()]);
+ return inputArguments.toArray(new INDArray[0]);
return new INDArray[0];
}
@ -389,6 +389,13 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
}
public void setInputArgument(int index, INDArray input) {
+ if(index >= inputArguments.size() ){
+ List<INDArray> oldArgs = inputArguments;
+ inputArguments = new ArrayList<>(index+1);
+ inputArguments.addAll(oldArgs);
+ while(inputArguments.size() <= index)
+ inputArguments.add(null);
+ }
inputArguments.set(index, input);
}
@ -400,12 +407,12 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
}
public void setOutputArgument(int index, INDArray output) {
- if(index == outputArguments.size()){
- //For example, setOutputArgument(0,arr) on empty list
- outputArguments.add(output);
- } else {
- outputArguments.set(index, output);
- }
+ while(index >= outputArguments.size()){
+ //Resize list, in case we want to specify arrays not in order they are defined
+ //For example, index 1 on empty list, then index 0
+ outputArguments.add(null);
+ }
+ outputArguments.set(index, output);
}
@Override
@ -608,6 +615,12 @@ public class DynamicCustomOp extends DifferentialFunction implements CustomOp {
}
+ @Override
+ public void clearArrays(){
+ inputArguments.clear();
+ outputArguments.clear();
+ }
protected static INDArray[] wrapOrNull(INDArray in){
return in == null ? null : new INDArray[]{in};
}
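A hedged usage sketch of the resizable argument lists and the new clearArrays() hook (hypothetical op name "add"; exact builder usage may differ):
DynamicCustomOp op = DynamicCustomOp.builder("add")
        .addInputs(Nd4j.ones(DataType.FLOAT, 3), Nd4j.ones(DataType.FLOAT, 3))
        .build();
op.setOutputArgument(0, Nd4j.create(DataType.FLOAT, 3));  // output list is grown as needed, in any order
Nd4j.exec(op);
op.clearArrays();  // drop input/output INDArray references once the results have been consumed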

View File

@ -167,4 +167,9 @@ public interface Op {
* @return the equivalent {@link CustomOp}
*/
CustomOp toCustomOp();
+ /**
+ * Clear the input and output INDArrays, if any are set
+ */
+ void clearArrays();
}

View File

@ -25,6 +25,6 @@ public class AdjustContrastV2 extends BaseAdjustContrast {
@Override
public String tensorflowName() {
- return "AdjustContrast";
+ return "AdjustContrastV2";
}
}

View File

@ -245,4 +245,9 @@ public class ScatterUpdate implements CustomOp {
public void assertValidForExecution() {
}
+ @Override
+ public void clearArrays() {
+ op.clearArrays();
+ }
}

View File

@ -39,13 +39,18 @@ import java.util.*;
@NoArgsConstructor
public class BiasAdd extends DynamicCustomOp {
+ protected boolean nchw = true;
- public BiasAdd(SameDiff sameDiff, SDVariable input, SDVariable bias) {
+ public BiasAdd(SameDiff sameDiff, SDVariable input, SDVariable bias, boolean nchw) {
super(null, sameDiff, new SDVariable[] {input, bias}, false);
+ bArguments.clear();
+ bArguments.add(nchw);
}
- public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output){
+ public BiasAdd(@NonNull INDArray input, @NonNull INDArray bias, INDArray output, boolean nchw){
super(new INDArray[]{input, bias}, wrapOrNull(output));
+ bArguments.clear();
+ bArguments.add(nchw);
}
@Override
@ -56,7 +61,11 @@ public class BiasAdd extends DynamicCustomOp {
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
+ if(attributesForNode.containsKey("data_format")){
+ nchw = "NCHW".equalsIgnoreCase(attributesForNode.get("data_format").getS().toStringUtf8());
+ }
+ bArguments.clear();
+ bArguments.add(nchw);
}
@Override
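A short sketch of the new nchw flag (hypothetical shapes; an NCHW input of shape [2,3,4,4] with a bias of length 3):
INDArray input = Nd4j.rand(DataType.FLOAT, 2, 3, 4, 4);
INDArray bias = Nd4j.rand(DataType.FLOAT, 3);
INDArray out = Nd4j.createUninitialized(DataType.FLOAT, 2, 3, 4, 4);
Nd4j.exec(new BiasAdd(input, bias, out, true));   // true => NCHW data format; false => NHWC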

View File

@ -1,402 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.controlflow;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.SameDiffConditional;
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.util.HashUtil;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.*;
/**
* Equivalent to tensorflow's conditional op.
* Runs one of 2 {@link SameDiff.SameDiffFunctionDefinition}
* depending on a predicate {@link org.nd4j.autodiff.samediff.SameDiff.SameDiffConditional}
*
*
* @author Adam Gibson
*/
@NoArgsConstructor
@Slf4j
public class If extends DifferentialFunction implements CustomOp {
@Getter
protected SameDiff loopBodyExecution,predicateExecution,falseBodyExecution;
@Getter
protected SameDiffConditional predicate;
@Getter
protected SameDiffFunctionDefinition trueBody,falseBody;
@Getter
protected String blockName,trueBodyName,falseBodyName;
@Getter
protected SDVariable[] inputVars;
@Getter
protected Boolean trueBodyExecuted = null;
@Getter
protected SDVariable targetBoolean;
protected SDVariable dummyResult;
@Getter
@Setter
protected SDVariable[] outputVars;
public If(If ifStatement) {
this.sameDiff = ifStatement.sameDiff;
this.outputVars = ifStatement.outputVars;
this.falseBodyExecution = ifStatement.falseBodyExecution;
this.trueBodyExecuted = ifStatement.trueBodyExecuted;
this.falseBody = ifStatement.falseBody;
this.trueBodyExecuted = ifStatement.trueBodyExecuted;
this.dummyResult = ifStatement.dummyResult;
this.inputVars = ifStatement.inputVars;
this.dummyResult = this.sameDiff.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme(), DataType.FLOAT, 1);
if(sameDiff.getShapeForVarName(dummyResult.getVarName()) == null)
sameDiff.putShapeForVarName(dummyResult.getVarName(),new long[]{1,1});
}
@Builder
public If(String blockName,
SameDiff parent,
SDVariable[] inputVars,
SameDiffFunctionDefinition conditionBody,
SameDiffConditional predicate,
SameDiffFunctionDefinition trueBody,
SameDiffFunctionDefinition falseBody) {
this.sameDiff = parent;
parent.putOpForId(getOwnName(),this);
this.inputVars = inputVars;
this.predicate = predicate;
parent.addArgsFor(inputVars,this);
this.trueBody = trueBody;
this.falseBody = falseBody;
this.blockName = blockName;
//need to add the op to the list of ops to be executed when running backwards
this.dummyResult = parent.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme('f'), DataType.FLOAT, 1);
parent.addOutgoingFor(new SDVariable[]{dummyResult},this);
//create a samediff sub graph for running just the execution
//return a reference to the loop for referencing during actual execution
SameDiff sameDiff = SameDiff.create();
//store the reference to the result array and the same diff execution instance
this.targetBoolean = predicate.eval(sameDiff,conditionBody, inputVars);
this.predicateExecution = sameDiff;
//store references to the loop body
String trueBodyName = "true-body-" + UUID.randomUUID().toString();
this.trueBodyName = trueBodyName;
String falseBodyName = "false-body-" + UUID.randomUUID().toString();
this.falseBodyName = trueBodyName;
//running define function will setup a proper same diff instance
this.loopBodyExecution = parent.defineFunction(trueBodyName,trueBody,inputVars);
this.falseBodyExecution = parent.defineFunction(falseBodyName,falseBody,inputVars);
parent.defineFunction(blockName,conditionBody,inputVars);
parent.putSubFunction("predicate-eval-body-" + UUID.randomUUID().toString(),sameDiff);
//get a reference to the actual loop body
this.loopBodyExecution = parent.getFunction(trueBodyName);
}
/**
* Toggle whether the true body was executed
* or the false body
* @param trueBodyExecuted
*/
public void exectedTrueOrFalse(boolean trueBodyExecuted) {
if(trueBodyExecuted)
this.trueBodyExecuted = true;
else
this.trueBodyExecuted = false;
}
@Override
public SDVariable[] outputVariables(String baseName) {
return new SDVariable[]{dummyResult};
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
List<SDVariable> ret = new ArrayList<>();
ret.addAll(Arrays.asList(new IfDerivative(this).outputVariables()));
return ret;
}
@Override
public String toString() {
return opName();
}
@Override
public String opName() {
return "if";
}
@Override
public long opHash() {
return HashUtil.getLongHash(opName());
}
@Override
public boolean isInplaceCall() {
return false;
}
@Override
public INDArray[] outputArguments() {
return new INDArray[0];
}
@Override
public INDArray[] inputArguments() {
return new INDArray[0];
}
@Override
public long[] iArgs() {
return new long[0];
}
@Override
public double[] tArgs() {
return new double[0];
}
@Override
public boolean[] bArgs() {
return new boolean[0];
}
@Override
public void addIArgument(int... arg) {
}
@Override
public void addIArgument(long... arg) {
}
@Override
public void addBArgument(boolean... arg) {
}
@Override
public void removeIArgument(Integer arg) {
}
@Override
public Boolean getBArgument(int index) {
return null;
}
@Override
public Long getIArgument(int index) {
return null;
}
@Override
public int numIArguments() {
return 0;
}
@Override
public void addTArgument(double... arg) {
}
@Override
public void removeTArgument(Double arg) {
}
@Override
public Double getTArgument(int index) {
return null;
}
@Override
public int numTArguments() {
return 0;
}
@Override
public int numBArguments() {
return 0;
}
@Override
public void addInputArgument(INDArray... arg) {
}
@Override
public void removeInputArgument(INDArray arg) {
}
@Override
public INDArray getInputArgument(int index) {
return null;
}
@Override
public int numInputArguments() {
return 0;
}
@Override
public void addOutputArgument(INDArray... arg) {
}
@Override
public void removeOutputArgument(INDArray arg) {
}
@Override
public INDArray getOutputArgument(int index) {
return null;
}
@Override
public int numOutputArguments() {
return 0;
}
@Override
public Op.Type opType() {
return Op.Type.CONDITIONAL;
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
//cond is only part of while loops
if(nodeDef.getName().contains("/cond/"))
return;
//usually should be a merge node for a conditional
val ifNodes = TFGraphMapper.getInstance().nodesForIf(nodeDef,graph);
val trueScopeGraphDefBuilder = GraphDef.newBuilder();
for(val node : ifNodes.getTrueNodes()) {
trueScopeGraphDefBuilder.addNode(node);
}
val trueScope = TFGraphMapper.getInstance().importGraph(trueScopeGraphDefBuilder.build());
val falseScopeGraphDefBuilder = GraphDef.newBuilder();
for(val node : ifNodes.getFalseNodes()) {
falseScopeGraphDefBuilder.addNode(node);
}
val falseScope = TFGraphMapper.getInstance().importGraph(falseScopeGraphDefBuilder.build());
val condScopeGraphDefBuilder = GraphDef.newBuilder();
for(val node : ifNodes.getCondNodes()) {
condScopeGraphDefBuilder.addNode(node);
}
val condScope = TFGraphMapper.getInstance().importGraph(condScopeGraphDefBuilder.build());
initWith.putSubFunction(ifNodes.getTrueBodyScopeName(),trueScope);
initWith.putSubFunction(ifNodes.getFalseBodyScopeName(),falseScope);
initWith.putSubFunction(ifNodes.getConditionBodyScopeName(),condScope);
this.loopBodyExecution = trueScope;
this.falseBodyExecution = falseScope;
this.predicateExecution = condScope;
}
@Override
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}
@Override
public List<LongShapeDescriptor> calculateOutputShape() {
return Arrays.asList(LongShapeDescriptor.fromShape(new long[0], DataType.BOOL));
}
@Override
public CustomOpDescriptor getDescriptor() {
return null;
}
@Override
public void assertValidForExecution() {
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("This operation has no TF counterpart");
}
}

View File

@ -1,93 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.controlflow;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.SameDiffConditional;
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import java.util.List;
@NoArgsConstructor
public class IfDerivative extends If {
private If ifDelegate;
public IfDerivative(If ifBlock) {
super(ifBlock);
this.ifDelegate = ifBlock;
}
@Override
public Boolean getTrueBodyExecuted() {
return ifDelegate.trueBodyExecuted;
}
@Override
public SameDiffFunctionDefinition getFalseBody() {
return ifDelegate.falseBody;
}
@Override
public SameDiff getFalseBodyExecution() {
return ifDelegate.falseBodyExecution;
}
@Override
public String getBlockName() {
return ifDelegate.blockName;
}
@Override
public String getFalseBodyName() {
return ifDelegate.falseBodyName;
}
@Override
public SameDiff getLoopBodyExecution() {
return ifDelegate.loopBodyExecution;
}
@Override
public SameDiffConditional getPredicate() {
return ifDelegate.getPredicate();
}
@Override
public SameDiff getPredicateExecution() {
return ifDelegate.predicateExecution;
}
@Override
public List<LongShapeDescriptor> calculateOutputShape() {
return super.calculateOutputShape();
}
@Override
public String opName() {
return "if_bp";
}
@Override
public List<SDVariable> diff(List<SDVariable> i_v1) {
throw new UnsupportedOperationException("Unable to take the derivative of the derivative for if");
}
}

View File

@ -1,32 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.controlflow;
import lombok.Builder;
import lombok.Data;
import org.tensorflow.framework.NodeDef;
import java.util.List;
@Builder
@Data
public class IfImportState {
private List<NodeDef> condNodes;
private List<NodeDef> trueNodes;
private List<NodeDef> falseNodes;
private String falseBodyScopeName,trueBodyScopeName,conditionBodyScopeName;
}


@@ -55,7 +55,7 @@ public class Select extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
     }
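Note: the one-line change above - dropping the TFGraphMapper.getInstance() singleton lookup in favour of the static initFunctionFromProperties call - is the same migration applied to most of the op classes in the hunks that follow. A minimal sketch of what a call site looks like after the change; the op class and op name below are hypothetical and exist only to illustrate the pattern:

import java.util.Map;

import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

// Hypothetical op, used only to show the new static-call style.
public class ExampleImportedOp extends DynamicCustomOp {

    @Override
    public String opName() {
        return "example_imported_op";   // hypothetical op name
    }

    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        // Before: TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
        // After: a plain static call, no singleton lookup.
        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
    }
}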


@@ -1,660 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.controlflow;
import lombok.*;
import lombok.extern.slf4j.Slf4j;
import onnx.Onnx;
import org.nd4j.autodiff.functions.DifferentialFunction;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.SameDiffConditional;
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.CustomOp;
import org.nd4j.linalg.api.ops.CustomOpDescriptor;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.shape.LongShapeDescriptor;
import org.nd4j.linalg.exception.ND4JIllegalArgumentException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.weightinit.impl.ZeroInitScheme;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;
/**
* Equivalent to tensorflow's while loop
* Takes in:
* loopVars
* loop body
* condition
*
* runs loop till condition is false.
* @author Adam Gibson
*/
@NoArgsConstructor
@Slf4j
public class While extends DifferentialFunction implements CustomOp {
private AtomicInteger startPosition;
@Getter
protected SameDiff loopBodyExecution,predicateExecution;
@Getter
protected SameDiffConditional predicate;
@Getter
protected SameDiffFunctionDefinition trueBody;
@Getter
protected String blockName,trueBodyName;
@Getter
protected SDVariable[] inputVars;
@Getter
protected SDVariable targetBoolean;
protected SDVariable dummyResult;
@Getter
@Setter
protected SDVariable[] outputVars;
@Getter
protected int numLooped = 0;
/**
* Mainly meant for tensorflow import.
* This allows {@link #initFromTensorFlow(NodeDef, SameDiff, Map, GraphDef)}
* to continue from a parent while loop
* using the same graph
* @param startPosition the start position for the import scan
*/
public While(AtomicInteger startPosition) {
this.startPosition = startPosition;
}
public While(While whileStatement) {
this.sameDiff = whileStatement.sameDiff;
this.outputVars = whileStatement.outputVars;
this.loopBodyExecution = whileStatement.loopBodyExecution;
this.numLooped = whileStatement.numLooped;
this.dummyResult = whileStatement.dummyResult;
this.predicate = whileStatement.predicate;
this.predicateExecution = whileStatement.predicateExecution;
this.inputVars = whileStatement.inputVars;
this.dummyResult = this.sameDiff.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme('f'), DataType.FLOAT, 1);
}
@Builder
public While(String blockName,
SameDiff parent,
SDVariable[] inputVars,
SameDiffConditional predicate,
SameDiffFunctionDefinition condition,
SameDiffFunctionDefinition trueBody) {
init(blockName,parent,inputVars,predicate,condition,trueBody);
}
private void init(String blockName,
SameDiff parent,
SDVariable[] inputVars,
SameDiffConditional predicate,
SameDiffFunctionDefinition condition,
SameDiffFunctionDefinition trueBody) {
this.sameDiff = parent;
this.inputVars = inputVars;
this.predicate = predicate;
this.trueBody = trueBody;
this.blockName = blockName;
this.dummyResult = parent.var("dummyresult-" + UUID.randomUUID().toString(),new ZeroInitScheme('f'), DataType.FLOAT, 1);
parent.putOpForId(getOwnName(),this);
parent.addArgsFor(inputVars,this);
parent.addOutgoingFor(new SDVariable[]{dummyResult},this);
//create a samediff sub graph for running just the execution
//return a reference to the loop for referencing during actual execution
SameDiff sameDiff = SameDiff.create();
//store the reference to the result array and the same diff execution instance
this.targetBoolean = predicate.eval(sameDiff,condition, inputVars);
this.predicateExecution = sameDiff;
//store references to the loop body
String trueBodyName = "true-body-" + UUID.randomUUID().toString();
this.trueBodyName = trueBodyName;
//running define function will setup a proper same diff instance
parent.defineFunction(trueBodyName,trueBody,inputVars);
parent.defineFunction(blockName,condition,inputVars);
parent.putSubFunction("predicate-eval-body",sameDiff);
//get a reference to the actual loop body
this.loopBodyExecution = parent.getFunction(trueBodyName);
}
@Override
public SDVariable[] outputVariables(String baseName) {
return new SDVariable[]{dummyResult};
}
@Override
public List<SDVariable> doDiff(List<SDVariable> f1) {
List<SDVariable> ret = new ArrayList<>();
ret.addAll(Arrays.asList(new WhileDerivative(this).outputVariables()));
return ret;
}
/**
* Increments the loop counter.
* This should be called when the loop
* actually executes.
*/
public void incrementLoopCounter() {
numLooped++;
}
@Override
public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
doImport(nodeDef,initWith,attributesForNode,graph,new LinkedHashSet<String>(),new AtomicInteger(0));
}
private void doImport(NodeDef nodeDef,SameDiff initWith,Map<String,AttrValue> attributesForNode,GraphDef graph,Set<String> skipSet,AtomicInteger currIndex) {
val uniqueId = java.util.UUID.randomUUID().toString();
skipSet.add(nodeDef.getName());
val scopeCondition = SameDiff.create();
val scopeLoop = SameDiff.create();
initWith.putSubFunction("condition-" + uniqueId,scopeCondition);
initWith.putSubFunction("loopbody-" + uniqueId,scopeLoop);
this.loopBodyExecution = scopeLoop;
this.predicateExecution = scopeCondition;
this.startPosition = currIndex;
log.info("Adding 2 new scopes for WHILE {}");
val nodes = graph.getNodeList();
/**
* Plan is simple:
* 1) we read all declarations of variables used within loop
* 2) we set up conditional scope
* 3) we set up body scope
* 4) ???
* 5) PROFIT!
*/
for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
val tfNode = nodes.get(currIndex.get());
if (!tfNode.getOp().equalsIgnoreCase("enter")) {
//skipSet.add(tfNode.getName());
break;
}
// if (skipSet.contains(tfNode.getName()))
// continue;
skipSet.add(tfNode.getName());
val vars = new SDVariable[tfNode.getInputCount()];
for (int e = 0; e < tfNode.getInputCount(); e++) {
val input = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(e));
vars[e] = initWith.getVariable(input) == null ? initWith.var(input, (LongShapeDescriptor) null,new ZeroInitScheme()) : initWith.getVariable(input);
scopeCondition.var(vars[e]);
scopeLoop.var(vars[e]);
}
this.inputVars = vars;
}
// now we're skipping Merge step, since we've already captured variables at Enter step
int mergedCnt = 0;
for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
val tfNode = nodes.get(currIndex.get());
if (!tfNode.getOp().equalsIgnoreCase("merge")) {
scopeLoop.var(TFGraphMapper.getInstance().getNodeName(tfNode.getName()), (LongShapeDescriptor) null,new ZeroInitScheme());
break;
}
skipSet.add(tfNode.getName());
val var = scopeLoop.var(TFGraphMapper.getInstance().getNodeName(tfNode.getName()), (LongShapeDescriptor)null,new ZeroInitScheme());
scopeCondition.var(var);
initWith.var(var);
mergedCnt++;
}
// now, we're adding conditional scope
for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
val tfNode = nodes.get(currIndex.get());
// we're parsing up to condition
if (tfNode.getOp().equalsIgnoreCase("LoopCond")) {
skipSet.add(tfNode.getName());
currIndex.incrementAndGet();
break;
}
boolean isConst = tfNode.getOp().equalsIgnoreCase("const");
boolean isVar = tfNode.getOp().startsWith("VariableV");
boolean isPlaceholder = tfNode.getOp().startsWith("Placeholder");
if (isConst || isVar || isPlaceholder) {
val var = scopeCondition.var(tfNode.getName(), (LongShapeDescriptor) null,new ZeroInitScheme());
scopeLoop.var(var);
initWith.var(var);
log.info("Adding condition var [{}]", var.getVarName());
}
else if(!skipSet.contains(tfNode.getName())) {
val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());
func.initFromTensorFlow(tfNode,scopeCondition,nodeDef.getAttrMap(),graph);
func.setSameDiff(scopeLoop);
}
skipSet.add(tfNode.getName());
}
// time to skip some Switch calls
int switchCnt = 0;
for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
val tfNode = nodes.get(currIndex.get());
// we're parsing up to condition
if (!tfNode.getOp().equalsIgnoreCase("Switch"))
break;
switchCnt++;
skipSet.add(tfNode.getName());
}
// now we're parsing Identity step
int identityCnt = 0;
for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
val tfNode = nodes.get(currIndex.get());
if (!tfNode.getOp().equalsIgnoreCase("Identity")) {
break;
}
val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());
func.initFromTensorFlow(tfNode,initWith,nodeDef.getAttrMap(),graph);
func.setSameDiff(scopeLoop);
val variables = new SDVariable[tfNode.getInputCount()];
for(int i = 0; i < tfNode.getInputCount(); i++) {
val testVar = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i)));
if(testVar == null) {
variables[i] = initWith.var(tfNode.getInput(i), (LongShapeDescriptor) null,new ZeroInitScheme());
scopeCondition.var(variables[i]);
scopeLoop.var(variables[i]);
continue;
}
else {
variables[i] = initWith.getVariable(TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i)));
scopeCondition.var(variables[i]);
scopeLoop.var(variables[i]);
}
}
scopeLoop.addArgsFor(variables,func);
skipSet.add(tfNode.getName());
}
// parsing body scope
for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
val tfNode = nodes.get(currIndex.get());
if (skipSet.contains(tfNode.getName())) {
log.info("Skipping: {}", tfNode.getName());
continue;
}
if (tfNode.getOp().equalsIgnoreCase("NextIteration")) {
// skipSet.add(tfNode.getName());
break;
}
if (skipSet.contains(tfNode.getName())) {
log.info("Skipping: {}", tfNode.getName());
continue;
}
boolean isConst = tfNode.getOp().equalsIgnoreCase("const");
boolean isVar = tfNode.getOp().startsWith("VariableV");
boolean isPlaceholder = tfNode.getOp().startsWith("Placeholder");
if (isConst || isVar || isPlaceholder) {
val var = scopeLoop.var(tfNode.getName(), (LongShapeDescriptor) null,new ZeroInitScheme());
log.info("Adding body var [{}]",var.getVarName());
} else {
log.info("starting on [{}]: {}", tfNode.getName(), tfNode.getOp());
if (tfNode.getOp().equalsIgnoreCase("enter")) {
log.info("NEW LOOP ----------------------------------------");
val func = new While(currIndex);
func.doImport(nodeDef,initWith,attributesForNode,graph,skipSet,currIndex);
func.setSameDiff(initWith);
log.info("END LOOP ----------------------------------------");
} else {
val func = DifferentialFunctionClassHolder.getInstance().getInstance(TFGraphMapper.getInstance().getMappedOp(tfNode.getOp()).opName());
func.initFromTensorFlow(tfNode,initWith,nodeDef.getAttrMap(),graph);
func.setSameDiff(scopeCondition);
val variables = new SDVariable[tfNode.getInputCount()];
for(int i = 0; i < tfNode.getInputCount(); i++) {
val name = TFGraphMapper.getInstance().getNodeName(tfNode.getInput(i));
variables[i] = scopeCondition.getVariable(name);
if(variables[i] == null) {
if(scopeLoop.getVariable(name) == null)
variables[i] = scopeCondition.var(initWith.getVariable(name));
else if(scopeLoop.getVariable(name) != null)
variables[i] = scopeLoop.getVariable(name);
else
variables[i] = scopeLoop.var(name, Nd4j.scalar(1.0));
}
}
scopeLoop.addArgsFor(variables,func);
}
}
skipSet.add(tfNode.getName());
}
val returnInputs = new ArrayList<SDVariable>();
val returnOutputs = new ArrayList<SDVariable>();
// mapping NextIterations, to Return op
for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
val tfNode = nodes.get(currIndex.get());
if (!tfNode.getOp().equalsIgnoreCase("NextIteration"))
break;
skipSet.add(tfNode.getName());
val inputName = TFGraphMapper.getInstance().getNodeName(tfNode.getName());
val input = initWith.getVariable(inputName) == null ? initWith.var(inputName, (LongShapeDescriptor) null,new ZeroInitScheme()) : initWith.getVariable(inputName) ;
returnInputs.add(input);
}
this.outputVars = returnOutputs.toArray(new SDVariable[returnOutputs.size()]);
this.inputVars = returnInputs.toArray(new SDVariable[returnInputs.size()]);
initWith.addArgsFor(inputVars,this);
initWith.addOutgoingFor(outputVars,this);
// we should also map While/Exit to libnd4j while
int exitCnt = 0;
for (; currIndex.get() < nodes.size(); currIndex.incrementAndGet()) {
val tfNode = nodes.get(currIndex.get());
if (!tfNode.getOp().equalsIgnoreCase("Exit")) {
//skipSet.add(tfNode.getName());
break;
}
skipSet.add(tfNode.getName());
val inputName = TFGraphMapper.getInstance().getNodeName(tfNode.getName());
val input = initWith.getVariable(inputName) == null ? initWith.var(inputName, (LongShapeDescriptor) null,new ZeroInitScheme()) : initWith.getVariable(inputName) ;
}
//the output of the condition should always be a singular scalar
//this is a safe assumption
val conditionVars = scopeCondition.ops();
if(conditionVars.length < 1) {
throw new ND4JIllegalArgumentException("No functions found!");
}
this.targetBoolean = conditionVars[conditionVars.length - 1].outputVariables()[0];
log.info("-------------------------------------------");
}
@Override
public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
}
@Override
public String toString() {
return opName();
}
@Override
public String opName() {
return "while";
}
@Override
public long opHash() {
return opName().hashCode();
}
@Override
public boolean isInplaceCall() {
return false;
}
@Override
public INDArray[] outputArguments() {
return new INDArray[0];
}
@Override
public INDArray[] inputArguments() {
return new INDArray[0];
}
@Override
public long[] iArgs() {
return new long[0];
}
@Override
public double[] tArgs() {
return new double[0];
}
@Override
public void addIArgument(int... arg) {
}
@Override
public void addIArgument(long... arg) {
}
@Override
public void removeIArgument(Integer arg) {
}
@Override
public Long getIArgument(int index) {
return null;
}
@Override
public int numIArguments() {
return 0;
}
@Override
public void addTArgument(double... arg) {
}
@Override
public void removeTArgument(Double arg) {
}
@Override
public Double getTArgument(int index) {
return null;
}
@Override
public int numTArguments() {
return 0;
}
@Override
public int numBArguments() {
return 0;
}
@Override
public void addInputArgument(INDArray... arg) {
}
@Override
public void removeInputArgument(INDArray arg) {
}
@Override
public boolean[] bArgs() {
return new boolean[0];
}
@Override
public void addBArgument(boolean... arg) {
}
@Override
public Boolean getBArgument(int index) {
return null;
}
@Override
public INDArray getInputArgument(int index) {
return null;
}
@Override
public int numInputArguments() {
return 0;
}
@Override
public void addOutputArgument(INDArray... arg) {
}
@Override
public void removeOutputArgument(INDArray arg) {
}
@Override
public INDArray getOutputArgument(int index) {
return null;
}
@Override
public int numOutputArguments() {
return 0;
}
@Override
public List<LongShapeDescriptor> calculateOutputShape() {
List<LongShapeDescriptor> ret = new ArrayList<>();
for(SDVariable var : args()) {
ret.add(sameDiff.getShapeDescriptorForVarName(var.getVarName()));
}
return ret;
}
@Override
public CustomOpDescriptor getDescriptor() {
return CustomOpDescriptor.builder().build();
}
@Override
public void assertValidForExecution() {
}
@Override
public String onnxName() {
throw new NoOpNameFoundException("No onnx op opName found for " + opName());
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No *singular (eg: use tensorflowNames() found for this op " + opName());
}
@Override
public String[] tensorflowNames() {
throw new NoOpNameFoundException("This operation has no TF counterpart");
}
@Override
public Op.Type opType() {
return Op.Type.LOOP;
}
}
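
For orientation on the While op deleted above (per its Javadoc it takes loop variables, a loop body and a condition, and runs until the condition is false), here is a heavily hedged sketch of how a block was wired up via its Lombok builder in the pre-change API. The builder setter names follow the @Builder constructor shown above; the shapes of SameDiffConditional.eval(SameDiff, SameDiffFunctionDefinition, SDVariable[]) and SameDiffFunctionDefinition.define(SameDiff, Map, SDVariable[]) are assumed from the call sites in init(), and the SDVariable.lt(double)/add(double) overloads are likewise assumed. Treat it as an illustration, not a verified snippet:

import java.util.Map;

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.SameDiffConditional;
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.controlflow.While;

public class WhileBuilderSketch {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        SDVariable[] loopVars = new SDVariable[]{sd.var("i", DataType.FLOAT, 1)};

        // Condition sub-graph: keep looping while i < 10 (SDVariable.lt(double) assumed).
        SameDiffFunctionDefinition condition = new SameDiffFunctionDefinition() {
            @Override
            public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
                return new SDVariable[]{variableInputs[0].lt(10.0)};
            }
        };

        // Loop body sub-graph: i = i + 1.
        SameDiffFunctionDefinition trueBody = new SameDiffFunctionDefinition() {
            @Override
            public SDVariable[] define(SameDiff sameDiff, Map<String, INDArray> inputs, SDVariable[] variableInputs) {
                return new SDVariable[]{variableInputs[0].add(1.0)};
            }
        };

        // Predicate evaluates the condition body and hands back its scalar boolean output,
        // mirroring the predicate.eval(sameDiff, condition, inputVars) call in While.init().
        SameDiffConditional predicate = new SameDiffConditional() {
            @Override
            public SDVariable eval(SameDiff context, SameDiffFunctionDefinition body, SDVariable[] inputVars) {
                return body.define(context, null, inputVars)[0];
            }
        };

        While whileBlock = While.builder()
                .blockName("while-example")
                .parent(sd)
                .inputVars(loopVars)
                .predicate(predicate)
                .condition(condition)
                .trueBody(trueBody)
                .build();

        System.out.println(whileBlock.opName());   // prints "while"
    }
}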


@@ -1,96 +0,0 @@
/*******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
* https://www.apache.org/licenses/LICENSE-2.0.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*
* SPDX-License-Identifier: Apache-2.0
******************************************************************************/
package org.nd4j.linalg.api.ops.impl.controlflow;
import lombok.NoArgsConstructor;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.autodiff.samediff.SameDiffConditional;
import org.nd4j.autodiff.samediff.SameDiffFunctionDefinition;
import org.nd4j.imports.NoOpNameFoundException;
import org.nd4j.linalg.api.ops.Op;
/**
* While loop derivative
* @author Adam Gibson
*/
@NoArgsConstructor
public class WhileDerivative extends While {
private While delegate;
public WhileDerivative(While delegate) {
super(delegate);
this.delegate = delegate;
}
@Override
public SameDiffFunctionDefinition getTrueBody() {
return delegate.trueBody;
}
@Override
public String getTrueBodyName() {
return delegate.getTrueBodyName();
}
@Override
public SameDiffConditional getPredicate() {
return delegate.getPredicate();
}
@Override
public SameDiff getPredicateExecution() {
return delegate.getPredicateExecution();
}
@Override
public SDVariable[] getInputVars() {
return delegate.getInputVars();
}
@Override
public String getBlockName() {
return delegate.getBlockName();
}
@Override
public SameDiff getLoopBodyExecution() {
return delegate.getLoopBodyExecution();
}
@Override
public int getNumLooped() {
return delegate.getNumLooped();
}
@Override
public String opName() {
return "while_bp";
}
@Override
public Op.Type opType() {
return Op.Type.CONDITIONAL;
}
@Override
public String tensorflowName() {
throw new NoOpNameFoundException("No tensorflow name for while backprop");
}
}


@@ -55,7 +55,7 @@ public abstract class BaseCompatOp extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode,nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode,nodeDef, graph);
     }
     @Override


@@ -32,9 +32,11 @@ import java.util.List;
 import java.util.Map;
 public class LoopCond extends BaseCompatOp {
+    public static final String OP_NAME = "loop_cond";
     @Override
     public String opName() {
-        return "loop_cond";
+        return OP_NAME;
     }
     @Override


@@ -74,8 +74,6 @@ public class CropAndResize extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         String method = attributesForNode.get("method").getS().toStringUtf8();
         if(method.equalsIgnoreCase("nearest")){
             this.method = Method.NEAREST;


@@ -120,4 +120,10 @@ public class ExtractImagePatches extends DynamicCustomOp {
         //TF includes redundant leading and training 1s for kSizes, strides, rates (positions 0/3)
         return new int[]{(int)ilist.getI(1), (int)ilist.getI(2)};
     }
+
+    @Override
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatypes for %s, got %s", getClass(), inputDataTypes);
+        return Collections.singletonList(inputDataTypes.get(0));
+    }
 }
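As the context comment in the hunk notes, TF stores kSizes, strides and rates as four-element lists with redundant 1s at positions 0 and 3, so only positions 1 and 2 carry information. A tiny stand-alone illustration of that indexing (the attribute values are made up):

public class KernelSizeIndexingSketch {
    public static void main(String[] args) {
        // Hypothetical TF attribute in NHWC layout: [1, height, width, 1]
        long[] tfKSizes = {1L, 3L, 4L, 1L};
        // Mirrors the indexing in the hunk above: positions 0 and 3 are redundant 1s,
        // so only positions 1 and 2 are kept.
        int[] kept = {(int) tfKSizes[1], (int) tfKSizes[2]};
        System.out.println(kept[0] + "x" + kept[1]);   // prints 3x4
    }
}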


@@ -74,7 +74,7 @@ public class ResizeBilinear extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         this.alignCorners = attributesForNode.get("align_corners").getB();
         addArgs();


@@ -50,7 +50,7 @@ public class ResizeNearestNeighbor extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
     }
     @Override


@@ -26,8 +26,6 @@ import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.autodiff.samediff.internal.SameDiffOp;
 import org.nd4j.base.Preconditions;
-import org.nd4j.imports.descriptors.properties.PropertyMapping;
-import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -41,7 +39,6 @@ import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
 import org.tensorflow.framework.NodeDef;
-import java.lang.reflect.Field;
 import java.util.*;
@@ -106,7 +103,7 @@ public class BatchNorm extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         //Switch order: TF uses [input, gamma, beta, mean, variance]; libnd4j expects [input, mean, variance, gamma, beta]
         SameDiffOp op = initWith.getOps().get(this.getOwnName());
         List<String> list = op.getInputsToOp();
@@ -140,8 +137,7 @@ public class BatchNorm extends DynamicCustomOp {
     @Override
     public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
-        OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
-        addArgs();
     }
     @Override


@@ -21,33 +21,20 @@ import lombok.Getter;
 import lombok.NoArgsConstructor;
 import lombok.NonNull;
 import lombok.extern.slf4j.Slf4j;
-import lombok.val;
-import onnx.Onnx;
 import org.nd4j.autodiff.samediff.SDVariable;
 import org.nd4j.autodiff.samediff.SameDiff;
 import org.nd4j.base.Preconditions;
 import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
-import org.nd4j.imports.descriptors.properties.AttributeAdapter;
-import org.nd4j.imports.descriptors.properties.PropertyMapping;
-import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter;
-import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueNDArrayShapeAdapter;
-import org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater;
-import org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter;
-import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
-import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv1DConfig;
-import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
 import org.nd4j.linalg.util.ArrayUtil;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.GraphDef;
-import org.tensorflow.framework.NodeDef;
 import java.lang.reflect.Field;
-import java.util.*;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
 /**


@@ -31,7 +31,6 @@ import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
 import org.nd4j.imports.descriptors.properties.AttributeAdapter;
 import org.nd4j.imports.descriptors.properties.PropertyMapping;
 import org.nd4j.imports.descriptors.properties.adapters.*;
-import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
@@ -122,7 +121,7 @@ public class Conv2D extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         addArgs();
     }
@@ -138,8 +137,7 @@ public class Conv2D extends DynamicCustomOp {
     @Override
     public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
-        OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
-        addArgs();
     }


@@ -251,7 +251,7 @@ public class Conv3D extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         addArgs();
     }


@@ -198,7 +198,7 @@ public class DeConv2D extends DynamicCustomOp {
         val args = args();
         INDArray arr = sameDiff.getVariable(args[1].getVarName()).getArr();
         if (arr == null) {
-            arr = TFGraphMapper.getInstance().getNDArrayFromTensor(nodeDef.getInput(0), nodeDef, graph);
+            arr = TFGraphMapper.getNDArrayFromTensor(nodeDef);
             // TODO: arguable. it might be easier to permute weights once
             //arr = (arr.permute(3, 2, 0, 1).dup('c'));
             val varForOp = initWith.getVariable(args[1].getVarName());


@@ -214,7 +214,7 @@ public class DeConv2DTF extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         addArgs();
     }
@@ -240,9 +240,9 @@
     }
     @Override
-    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ //inShape, weights, input
         int n = args().length;
         Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes);
-        return Collections.singletonList(inputDataTypes.get(0));
+        return Collections.singletonList(inputDataTypes.get(2));
     }
 }
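The second hunk corrects output datatype inference: DeConv2DTF takes [output_shape, weights, input] (per the //inShape, weights, input comment), so the result datatype should follow the real input at index 2, not the integer shape vector at index 0. A minimal illustration with made-up datatypes:

import java.util.Arrays;
import java.util.List;

import org.nd4j.linalg.api.buffer.DataType;

public class DeConv2DTFDtypeSketch {
    public static void main(String[] args) {
        // Hypothetical input datatypes in the op's argument order: output_shape, weights, input
        List<DataType> in = Arrays.asList(DataType.INT, DataType.FLOAT, DataType.FLOAT);
        // The old behaviour reported in.get(0) (INT, the shape vector);
        // the fixed calculateOutputDataTypes reports in.get(2) (FLOAT, the actual input).
        System.out.println(in.get(2));
    }
}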


@@ -160,7 +160,7 @@ public class DeConv3D extends DynamicCustomOp {
         val args = args();
         INDArray arr = sameDiff.getVariable(args[1].getVarName()).getArr();
         if (arr == null) {
-            arr = TFGraphMapper.getInstance().getNDArrayFromTensor(nodeDef.getInput(0), nodeDef, graph);
+            arr = TFGraphMapper.getNDArrayFromTensor(nodeDef);
             val varForOp = initWith.getVariable(args[1].getVarName());
             if (arr != null)
                 initWith.associateArrayWithVariable(arr, varForOp);


@@ -77,7 +77,7 @@ public class DepthToSpace extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         boolean isNHWC = dataFormat.equals("NHWC");
         addIArgument(blockSize, isNHWC ? 1 : 0);
     }


@@ -29,14 +29,15 @@ import org.nd4j.imports.NoOpNameFoundException;
 import org.nd4j.imports.converters.DifferentialFunctionClassHolder;
 import org.nd4j.imports.descriptors.properties.AttributeAdapter;
 import org.nd4j.imports.descriptors.properties.PropertyMapping;
-import org.nd4j.imports.descriptors.properties.adapters.*;
-import org.nd4j.imports.graphmapper.onnx.OnnxGraphMapper;
+import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter;
+import org.nd4j.imports.descriptors.properties.adapters.NDArrayShapeAdapter;
+import org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater;
+import org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter;
 import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
 import org.nd4j.linalg.api.buffer.DataType;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.DynamicCustomOp;
 import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig;
-import org.nd4j.linalg.api.ops.impl.layers.convolution.config.DeConv3DConfig;
 import org.nd4j.linalg.util.ArrayUtil;
 import org.tensorflow.framework.AttrValue;
 import org.tensorflow.framework.GraphDef;
@@ -136,7 +137,7 @@ public class DepthwiseConv2D extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         addArgs();
         /*
@@ -162,8 +163,7 @@ public class DepthwiseConv2D extends DynamicCustomOp {
     @Override
     public void initFromOnnx(Onnx.NodeProto node, SameDiff initWith, Map<String, Onnx.AttributeProto> attributesForNode, Onnx.GraphProto graph) {
-        OnnxGraphMapper.getInstance().initFunctionFromProperties(node.getOpType(), this, attributesForNode, node, graph);
-        addArgs();
     }


@@ -75,7 +75,7 @@ public class SpaceToDepth extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         boolean isNHWC = dataFormat == null ? true : dataFormat.equals("NHWC");
         addIArgument(blockSize, isNHWC ? 1 : 0);
     }


@@ -64,7 +64,7 @@ public class SoftmaxCrossEntropyLoss extends BaseLoss {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         addArgs();
     }


@@ -55,7 +55,7 @@ public class SparseSoftmaxCrossEntropyLossWithLogits extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         //Switch order: TF uses [logits, labels]; libnd4j expects [labels, logits]
         SameDiffOp op = initWith.getOps().get(this.getOwnName());


@@ -64,7 +64,7 @@ public class Moments extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         addArgs();
     }


@@ -60,7 +60,7 @@ public class NormalizeMoments extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         addArgs();
     }


@@ -63,7 +63,7 @@ public class ScatterAdd extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         if (nodeDef.containsAttr("use_locking")) {
             if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {


@@ -86,7 +86,7 @@ public class ScatterDiv extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         if (nodeDef.containsAttr("use_locking")) {
             if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {


@@ -60,7 +60,7 @@ public class ScatterMax extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         if (nodeDef.containsAttr("use_locking")) {
             if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {


@@ -60,7 +60,7 @@ public class ScatterMin extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         if (nodeDef.containsAttr("use_locking")) {
             if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {


@@ -62,7 +62,7 @@ public class ScatterMul extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         if (nodeDef.containsAttr("use_locking")) {
             if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {


@@ -67,7 +67,7 @@ public class ScatterNd extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         if (nodeDef.containsAttr("use_locking")) {
             if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {
@@ -80,8 +80,8 @@ public class ScatterNd extends DynamicCustomOp {
     }
     @Override
-    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){
-        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 2, "Expected exactly 2 input datatypes for %s, got %s", getClass(), inputDataTypes);
+    public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes){ //Indices, updates, shape
+        Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 3, "Expected exactly 3 input datatypes for %s, got %s", getClass(), inputDataTypes);
         return Collections.singletonList(inputDataTypes.get(1));
     }


@@ -66,7 +66,7 @@ public class ScatterNdAdd extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         if (nodeDef.containsAttr("use_locking")) {
             if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {


@@ -66,7 +66,7 @@ public class ScatterNdSub extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         if (nodeDef.containsAttr("use_locking")) {
             if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {


@@ -66,7 +66,7 @@ public class ScatterNdUpdate extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         if (nodeDef.containsAttr("use_locking")) {
             if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {


@@ -79,7 +79,7 @@ public class ScatterSub extends DynamicCustomOp {
     @Override
     public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
-        TFGraphMapper.getInstance().initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
+        TFGraphMapper.initFunctionFromProperties(nodeDef.getOp(), this, attributesForNode, nodeDef, graph);
         if (nodeDef.containsAttr("use_locking")) {
             if (nodeDef.getAttrOrThrow("use_locking").getB() == true) {

Some files were not shown because too many files have changed in this diff.